with unroll, the action space goes from 161 -> 127 (#2060)

* with unroll, the action space goes from 161 -> 127

* more reliable instrumentation

* beam search is so op

* beam bugfix
This commit is contained in:
George Hotz
2023-10-12 20:52:23 -07:00
committed by GitHub
parent 6b7ac5c431
commit 6f1810af2d
5 changed files with 86 additions and 37 deletions

View File

@@ -4,7 +4,7 @@ from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps, Device, Compiled
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.search import bufs_from_lin, time_linearizer, get_linearizer_actions
from tinygrad.helpers import ansilen, DEBUG, getenv
from tinygrad.helpers import ansilen, DEBUG, getenv, flatten
from tinygrad.graph import print_tree
from tinygrad.lazy import vars_from_ast
from tinygrad.shape.symbolic import sym_infer
@@ -52,20 +52,24 @@ if __name__ == "__main__":
if lin.apply_tensor_cores():
lins.append(lin)
# try a greedy search
if getenv("GREEDY"):
# try a beam search
if getenv("BEAM"):
lin = Linearizer(si.ast, device.linearizer_opts)
if str(lin.ast) in global_db:
for ao in global_db[str(lin.ast)]:
lin.apply_opt(ao)
else:
best_tm = float('inf')
beam = [lin]
while 1:
acted_lins = get_linearizer_actions(lin)
timed_lins = {k:time_linearizer(v, rawbufs) for k,v in acted_lins.items()}
opts = sorted(timed_lins.items(), key=lambda x: x[1])
if opts[0][0] == 0: break # we are done
lin = acted_lins[opts[0][0]]
if DEBUG >= 1: print(f"{opts[0][1]*1e3:10.2f} ms from {len(opts):3d} actions", lin.colored_shape())
acted_lins = flatten([get_linearizer_actions(lin).items() for lin in beam])
timed_lins = [(v,time_linearizer(v, rawbufs)) for k,v in acted_lins if k != 0]
opts = sorted(timed_lins, key=lambda x: x[1])
if len(opts) == 0 or best_tm <= opts[0][1]: break # we didn't get faster
best_tm = opts[0][1]
beam = [x[0] for x in opts[:getenv("BEAM")]]
if DEBUG >= 1: print(f"{opts[0][1]*1e3:10.2f} ms from {len(opts):3d} actions", beam[0].colored_shape())
lin = beam[0]
global_db[str(lin.ast)] = lin.applied_opts
lins.append(lin)

View File

@@ -0,0 +1,15 @@
from tqdm import tqdm
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.codegen.search import actions
if __name__ == "__main__":
  # Gather every distinct opt that the hand-coded optimizer applies over the loaded
  # kernels, checking along the way that each one exists in the search `actions` list.
  seen_opts = set()
  for kernel_src in tqdm(load_worlds(False, False, False)):
    kernel = ast_str_to_lin(kernel_src)
    kernel.hand_coded_optimizations()
    for applied in kernel.applied_opts:
      # check membership before recording, so a missing action aborts the run
      assert applied in actions
      seen_opts.add(applied)

  # report the number of distinct opts, then the opts themselves in sorted order
  print(len(seen_opts))
  print(sorted(seen_opts))

View File

@@ -0,0 +1,21 @@
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.codegen.search import bufs_from_lin, time_linearizer, get_linearizer_actions
if __name__ == "__main__":
  from math import isfinite  # local import: only this script entry point needs it

  # For every kernel that is slow enough to measure, time each candidate action three
  # times: twice with the default settings (tm1, tm2) and once with allow_test_size=False
  # (tm3). Print the relative drift of tm2 and tm3 against tm1, then the average drift.
  ast_strs = load_worlds()
  for i, ast_str in enumerate(ast_strs):
    lin = ast_str_to_lin(ast_str)
    rawbufs = bufs_from_lin(lin)
    test_tm = time_linearizer(lin, rawbufs)
    if test_tm < 1e-2: continue  # time is in seconds: skip kernels faster than 10 ms
    print(f"EXAMPLE {i}")
    acted_lins = get_linearizer_actions(lin)
    ok_avg, short_avg, counted = 0, 0, 0
    for v in acted_lins.values():
      tm1 = time_linearizer(v, rawbufs)
      tm2 = time_linearizer(v, rawbufs)
      tm3 = time_linearizer(v, rawbufs, False)
      ok_delta = (tm1-tm2)/tm1
      short_delta = (tm1-tm3)/tm1
      print(v.colored_shape(50), f"{tm1*1e3:10.2f} {tm2*1e3:10.2f} {tm3*1e3:10.2f} : {ok_delta*100:5.2f}% vs {short_delta*100:5.2f}%")
      # time_linearizer returns float('inf') when a run fails, which makes the deltas
      # nan or +/-inf; adding one of those would turn the whole average into nan.
      if not (isfinite(ok_delta) and isfinite(short_delta)): continue
      ok_avg += ok_delta
      short_avg += short_delta
      counted += 1
    # average only over the actions that produced usable timings
    if counted: print(f"{ok_avg/counted*100:5.2f}% vs {short_avg/counted*100:5.2f}%")
    else: print("no successful timings")

View File

@@ -10,7 +10,7 @@ from tinygrad.shape.view import View, strides_for_shape
from enum import Enum, auto
class OptOps(Enum):
UPCAST = auto(); LOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702
UPCAST = auto(); UNROLL = auto(); LOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702
def __lt__(self, x:OptOps): return self.value < x.value
@dataclass(frozen=True, order=True)
@@ -322,20 +322,25 @@ class OptimizedKernel(Kernel):
def apply_opt(self, opt:Opt):
self.applied_opts.append(opt)
assert self.full_shape[opt.axis] % opt.amt == 0, "no longer valid shift"
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else 0)
assert self.full_shape[axis] % opt.amt == 0, "no longer valid shift"
if opt.op == OptOps.LOCAL: # cyan
assert opt.axis < (self.first_reduce-self.local_dims), "can't local a local or reduce"
self.shift_to(opt.axis, opt.amt, insert_before=self.first_reduce)
assert axis < (self.first_reduce-self.local_dims), "can't local a local or reduce"
self.shift_to(axis, opt.amt, insert_before=self.first_reduce)
self.local_dims += 1
elif opt.op == OptOps.GROUP: # green
self.shift_to(opt.axis, opt.amt, insert_before=self.first_reduce + len(self.group_for_reduce))
self.shift_to(axis, opt.amt, insert_before=self.first_reduce + len(self.group_for_reduce))
self.group_for_reduce.append(opt.amt)
elif opt.op == OptOps.GROUPTOP: # green
self.shift_to(opt.axis, opt.amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
self.shift_to(axis, opt.amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
self.group_for_reduce.append(opt.amt)
elif opt.op == OptOps.UPCAST: # yellow (or purple if it's a reduce axis)
assert opt.axis < self.shape_len-self.upcasted, "can't upcasted already upcasted"
self.shift_to(opt.axis, opt.amt, insert_before=None if opt.axis < self.first_reduce else len(self.full_unupcasted_shape))
elif opt.op == OptOps.UNROLL: # purple
assert axis < self.shape_len-self.upcasted, "can't upcasted already upcasted"
self.shift_to(axis, opt.amt, insert_before=len(self.full_unupcasted_shape))
self.upcast()
elif opt.op == OptOps.UPCAST: # yellow
assert axis < self.first_reduce, "upcast is for non-reduce"
self.shift_to(axis, opt.amt, insert_before=None)
self.upcast()
self.simplify_ones()
@@ -345,7 +350,8 @@ class OptimizedKernel(Kernel):
if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType:
assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
if unit_stride_axes_mul_4[0] < self.first_reduce: self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
else: self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))
def hand_coded_optimizations(self):
# if there's images in the earlybufs, we have to make an axis the 4 loading one
@@ -432,14 +438,14 @@ class OptimizedKernel(Kernel):
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))):
if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, s))
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, s))
# if it's small, upcast a second reduce dimension too
if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int):
self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, s2))
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, s2))
else:
for splits in [4]:
if self.full_unupcasted_shape[-1]%splits == 0:
self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits))
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits))
break
# if nothing at all is upcasted and it's easy to, do an upcast

View File

@@ -9,20 +9,22 @@ from collections import defaultdict
from tinygrad.codegen.optimizer import Opt, OptOps
actions = [
Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=3), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=5), Opt(op=OptOps.UPCAST, axis=0, amt=6), Opt(op=OptOps.UPCAST, axis=0, amt=7), Opt(op=OptOps.UPCAST, axis=0, amt=9), Opt(op=OptOps.UPCAST, axis=0, amt=10), Opt(op=OptOps.UPCAST, axis=0, amt=12), Opt(op=OptOps.UPCAST, axis=0, amt=24),
Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=5), Opt(op=OptOps.UPCAST, axis=1, amt=6), Opt(op=OptOps.UPCAST, axis=1, amt=7), Opt(op=OptOps.UPCAST, axis=1, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=9), Opt(op=OptOps.UPCAST, axis=1, amt=10), Opt(op=OptOps.UPCAST, axis=1, amt=11), Opt(op=OptOps.UPCAST, axis=1, amt=12), Opt(op=OptOps.UPCAST, axis=1, amt=13), Opt(op=OptOps.UPCAST, axis=1, amt=14), Opt(op=OptOps.UPCAST, axis=1, amt=15), Opt(op=OptOps.UPCAST, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=17), Opt(op=OptOps.UPCAST, axis=1, amt=20), Opt(op=OptOps.UPCAST, axis=1, amt=21), Opt(op=OptOps.UPCAST, axis=1, amt=24), Opt(op=OptOps.UPCAST, axis=1, amt=25), Opt(op=OptOps.UPCAST, axis=1, amt=28), Opt(op=OptOps.UPCAST, axis=1, amt=32),
Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=5), Opt(op=OptOps.UPCAST, axis=2, amt=6), Opt(op=OptOps.UPCAST, axis=2, amt=7), Opt(op=OptOps.UPCAST, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=2, amt=9), Opt(op=OptOps.UPCAST, axis=2, amt=10), Opt(op=OptOps.UPCAST, axis=2, amt=11), Opt(op=OptOps.UPCAST, axis=2, amt=12), Opt(op=OptOps.UPCAST, axis=2, amt=14), Opt(op=OptOps.UPCAST, axis=2, amt=15), Opt(op=OptOps.UPCAST, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=2, amt=17), Opt(op=OptOps.UPCAST, axis=2, amt=20), Opt(op=OptOps.UPCAST, axis=2, amt=21), Opt(op=OptOps.UPCAST, axis=2, amt=24), Opt(op=OptOps.UPCAST, axis=2, amt=25), Opt(op=OptOps.UPCAST, axis=2, amt=27), Opt(op=OptOps.UPCAST, axis=2, amt=28), Opt(op=OptOps.UPCAST, axis=2, amt=30), Opt(op=OptOps.UPCAST, axis=2, amt=31), Opt(op=OptOps.UPCAST, axis=2, amt=32),
Opt(op=OptOps.UPCAST, axis=3, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=5), Opt(op=OptOps.UPCAST, axis=3, amt=6), Opt(op=OptOps.UPCAST, axis=3, amt=7), Opt(op=OptOps.UPCAST, axis=3, amt=8), Opt(op=OptOps.UPCAST, axis=3, amt=9), Opt(op=OptOps.UPCAST, axis=3, amt=10), Opt(op=OptOps.UPCAST, axis=3, amt=11), Opt(op=OptOps.UPCAST, axis=3, amt=12), Opt(op=OptOps.UPCAST, axis=3, amt=13), Opt(op=OptOps.UPCAST, axis=3, amt=14), Opt(op=OptOps.UPCAST, axis=3, amt=15), Opt(op=OptOps.UPCAST, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=3, amt=24), Opt(op=OptOps.UPCAST, axis=3, amt=26), Opt(op=OptOps.UPCAST, axis=3, amt=27), Opt(op=OptOps.UPCAST, axis=3, amt=28), Opt(op=OptOps.UPCAST, axis=3, amt=32),
Opt(op=OptOps.UPCAST, axis=4, amt=2), Opt(op=OptOps.UPCAST, axis=4, amt=3), Opt(op=OptOps.UPCAST, axis=4, amt=4), Opt(op=OptOps.UPCAST, axis=4, amt=5), Opt(op=OptOps.UPCAST, axis=4, amt=6), Opt(op=OptOps.UPCAST, axis=4, amt=7), Opt(op=OptOps.UPCAST, axis=4, amt=8), Opt(op=OptOps.UPCAST, axis=4, amt=9), Opt(op=OptOps.UPCAST, axis=4, amt=10), Opt(op=OptOps.UPCAST, axis=4, amt=11), Opt(op=OptOps.UPCAST, axis=4, amt=12), Opt(op=OptOps.UPCAST, axis=4, amt=13), Opt(op=OptOps.UPCAST, axis=4, amt=14), Opt(op=OptOps.UPCAST, axis=4, amt=15), Opt(op=OptOps.UPCAST, axis=4, amt=16), Opt(op=OptOps.UPCAST, axis=4, amt=20), Opt(op=OptOps.UPCAST, axis=4, amt=24), Opt(op=OptOps.UPCAST, axis=4, amt=26), Opt(op=OptOps.UPCAST, axis=4, amt=27), Opt(op=OptOps.UPCAST, axis=4, amt=28), Opt(op=OptOps.UPCAST, axis=4, amt=30), Opt(op=OptOps.UPCAST, axis=4, amt=32),
Opt(op=OptOps.UPCAST, axis=5, amt=2), Opt(op=OptOps.UPCAST, axis=5, amt=3), Opt(op=OptOps.UPCAST, axis=5, amt=4), Opt(op=OptOps.UPCAST, axis=5, amt=5), Opt(op=OptOps.UPCAST, axis=5, amt=6), Opt(op=OptOps.UPCAST, axis=5, amt=7), Opt(op=OptOps.UPCAST, axis=5, amt=8), Opt(op=OptOps.UPCAST, axis=5, amt=9), Opt(op=OptOps.UPCAST, axis=5, amt=11), Opt(op=OptOps.UPCAST, axis=5, amt=13), Opt(op=OptOps.UPCAST, axis=5, amt=16), Opt(op=OptOps.UPCAST, axis=5, amt=24), Opt(op=OptOps.UPCAST, axis=5, amt=26), Opt(op=OptOps.UPCAST, axis=5, amt=27),
Opt(op=OptOps.UPCAST, axis=6, amt=2), Opt(op=OptOps.UPCAST, axis=6, amt=3), Opt(op=OptOps.UPCAST, axis=6, amt=4), Opt(op=OptOps.UPCAST, axis=6, amt=5), Opt(op=OptOps.UPCAST, axis=6, amt=6), Opt(op=OptOps.UPCAST, axis=6, amt=7), Opt(op=OptOps.UPCAST, axis=6, amt=13), Opt(op=OptOps.UPCAST, axis=6, amt=16), Opt(op=OptOps.UPCAST, axis=6, amt=24), Opt(op=OptOps.UPCAST, axis=6, amt=26), Opt(op=OptOps.UPCAST, axis=6, amt=27), Opt(op=OptOps.UPCAST, axis=6, amt=31),
Opt(op=OptOps.UPCAST, axis=7, amt=2), Opt(op=OptOps.UPCAST, axis=7, amt=3), Opt(op=OptOps.UPCAST, axis=7, amt=4), Opt(op=OptOps.UPCAST, axis=7, amt=7),
Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=3), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=5), Opt(op=OptOps.UPCAST, axis=0, amt=6), Opt(op=OptOps.UPCAST, axis=0, amt=7),
Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=5), Opt(op=OptOps.UPCAST, axis=1, amt=6), Opt(op=OptOps.UPCAST, axis=1, amt=7),
Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=5), Opt(op=OptOps.UPCAST, axis=2, amt=6), Opt(op=OptOps.UPCAST, axis=2, amt=7),
Opt(op=OptOps.UPCAST, axis=3, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=5), Opt(op=OptOps.UPCAST, axis=3, amt=6), Opt(op=OptOps.UPCAST, axis=3, amt=7),
Opt(op=OptOps.UPCAST, axis=4, amt=2), Opt(op=OptOps.UPCAST, axis=4, amt=3), Opt(op=OptOps.UPCAST, axis=4, amt=4), Opt(op=OptOps.UPCAST, axis=4, amt=5),
Opt(op=OptOps.UPCAST, axis=5, amt=3), Opt(op=OptOps.UPCAST, axis=5, amt=4),
Opt(op=OptOps.UNROLL, axis=0, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=3), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=5), Opt(op=OptOps.UNROLL, axis=0, amt=6), Opt(op=OptOps.UNROLL, axis=0, amt=7), Opt(op=OptOps.UNROLL, axis=0, amt=8), Opt(op=OptOps.UNROLL, axis=0, amt=9), Opt(op=OptOps.UNROLL, axis=0, amt=10), Opt(op=OptOps.UNROLL, axis=0, amt=11), Opt(op=OptOps.UNROLL, axis=0, amt=12), Opt(op=OptOps.UNROLL, axis=0, amt=13), Opt(op=OptOps.UNROLL, axis=0, amt=14), Opt(op=OptOps.UNROLL, axis=0, amt=15), Opt(op=OptOps.UNROLL, axis=0, amt=16), Opt(op=OptOps.UNROLL, axis=0, amt=17), Opt(op=OptOps.UNROLL, axis=0, amt=20), Opt(op=OptOps.UNROLL, axis=0, amt=21), Opt(op=OptOps.UNROLL, axis=0, amt=24), Opt(op=OptOps.UNROLL, axis=0, amt=25), Opt(op=OptOps.UNROLL, axis=0, amt=27), Opt(op=OptOps.UNROLL, axis=0, amt=28), Opt(op=OptOps.UNROLL, axis=0, amt=30), Opt(op=OptOps.UNROLL, axis=0, amt=31), Opt(op=OptOps.UNROLL, axis=0, amt=32),
Opt(op=OptOps.UNROLL, axis=1, amt=2), Opt(op=OptOps.UNROLL, axis=1, amt=3), Opt(op=OptOps.UNROLL, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=1, amt=5), Opt(op=OptOps.UNROLL, axis=1, amt=6), Opt(op=OptOps.UNROLL, axis=1, amt=7), Opt(op=OptOps.UNROLL, axis=1, amt=8), Opt(op=OptOps.UNROLL, axis=1, amt=9), Opt(op=OptOps.UNROLL, axis=1, amt=10), Opt(op=OptOps.UNROLL, axis=1, amt=11), Opt(op=OptOps.UNROLL, axis=1, amt=16), Opt(op=OptOps.UNROLL, axis=1, amt=24), Opt(op=OptOps.UNROLL, axis=1, amt=26), Opt(op=OptOps.UNROLL, axis=1, amt=27), Opt(op=OptOps.UNROLL, axis=1, amt=28), Opt(op=OptOps.UNROLL, axis=1, amt=32),
Opt(op=OptOps.UNROLL, axis=2, amt=2), Opt(op=OptOps.UNROLL, axis=2, amt=3), Opt(op=OptOps.UNROLL, axis=2, amt=4), Opt(op=OptOps.UNROLL, axis=2, amt=5), Opt(op=OptOps.UNROLL, axis=2, amt=6), Opt(op=OptOps.UNROLL, axis=2, amt=7), Opt(op=OptOps.UNROLL, axis=2, amt=8), Opt(op=OptOps.UNROLL, axis=2, amt=9), Opt(op=OptOps.UNROLL, axis=2, amt=11), Opt(op=OptOps.UNROLL, axis=2, amt=12), Opt(op=OptOps.UNROLL, axis=2, amt=13), Opt(op=OptOps.UNROLL, axis=2, amt=14), Opt(op=OptOps.UNROLL, axis=2, amt=16), Opt(op=OptOps.UNROLL, axis=2, amt=24), Opt(op=OptOps.UNROLL, axis=2, amt=26), Opt(op=OptOps.UNROLL, axis=2, amt=27), Opt(op=OptOps.UNROLL, axis=2, amt=28), Opt(op=OptOps.UNROLL, axis=2, amt=30), Opt(op=OptOps.UNROLL, axis=2, amt=31), Opt(op=OptOps.UNROLL, axis=2, amt=32),
Opt(op=OptOps.UNROLL, axis=3, amt=2), Opt(op=OptOps.UNROLL, axis=3, amt=3), Opt(op=OptOps.UNROLL, axis=3, amt=11),
Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=3), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.LOCAL, axis=0, amt=32),
Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=3), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=1, amt=8), Opt(op=OptOps.LOCAL, axis=1, amt=16),
Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.LOCAL, axis=2, amt=16),
Opt(op=OptOps.LOCAL, axis=3, amt=2), Opt(op=OptOps.LOCAL, axis=3, amt=3), Opt(op=OptOps.LOCAL, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=8), Opt(op=OptOps.LOCAL, axis=3, amt=16),
Opt(op=OptOps.LOCAL, axis=4, amt=2), Opt(op=OptOps.LOCAL, axis=4, amt=3), Opt(op=OptOps.LOCAL, axis=4, amt=16),
Opt(op=OptOps.GROUP, axis=1, amt=4), Opt(op=OptOps.GROUP, axis=1, amt=8),
Opt(op=OptOps.GROUP, axis=1, amt=4), Opt(op=OptOps.GROUP, axis=1, amt=8), Opt(op=OptOps.GROUP, axis=2, amt=8),
Opt(op=OptOps.GROUPTOP, axis=0, amt=16), Opt(op=OptOps.GROUPTOP, axis=0, amt=256),
Opt(op=OptOps.GROUPTOP, axis=1, amt=16), Opt(op=OptOps.GROUPTOP, axis=1, amt=256),
Opt(op=OptOps.GROUPTOP, axis=2, amt=16), Opt(op=OptOps.GROUPTOP, axis=2, amt=256)]
@@ -30,7 +32,7 @@ device:Compiled = cast(Compiled, Device[Device.DEFAULT])
# returns time in seconds
logtm = open(getenv("LOGTM", ""),"a") if getenv("LOGTM", "") else None
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, cnt=3, should_copy=True) -> float:
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True) -> float:
if should_copy: lin = deepcopy(lin) # TODO: remove the need for this
var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
try:
@@ -39,21 +41,22 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru
real_global_size = prg.global_size[:]
if allow_test_size:
test_global_size = prg.global_size[:]
while prod(test_global_size) > 16384:
while prod(test_global_size) > max_global_size:
for j in range(2,-1,-1):
if test_global_size[j] > 1:
if test_global_size[j] > 16:
test_global_size[j] //= 2
break
factor = prod(prg.global_size) / prod(test_global_size)
prg.global_size = test_global_size
#print(real_global_size, test_global_size, factor)
else:
factor = 1
tms = [prg(rawbufs, var_vals, force_wait=True)*factor for _ in range(cnt)]
prg.global_size = real_global_size
except Exception:
print("FAILED")
print(lin.ast)
print(lin.applied_opts)
#print("FAILED")
#print(lin.ast)
#print(lin.applied_opts)
tms = [float('inf')]
if logtm: logtm.write(str((lin.ast, lin.applied_opts, tms))+"\n")
return min(tms)