diff --git a/examples/handcode_resnet50_opt.py b/examples/handcode_resnet50_opt.py index cceab51942..b5a5943a88 100644 --- a/examples/handcode_resnet50_opt.py +++ b/examples/handcode_resnet50_opt.py @@ -4,7 +4,7 @@ from tinygrad.tensor import Tensor from tinygrad.ops import LoadOps, Device, Compiled from tinygrad.codegen.linearizer import Linearizer from tinygrad.codegen.search import bufs_from_lin, time_linearizer, get_linearizer_actions -from tinygrad.helpers import ansilen, DEBUG, getenv +from tinygrad.helpers import ansilen, DEBUG, getenv, flatten from tinygrad.graph import print_tree from tinygrad.lazy import vars_from_ast from tinygrad.shape.symbolic import sym_infer @@ -52,20 +52,24 @@ if __name__ == "__main__": if lin.apply_tensor_cores(): lins.append(lin) - # try a greedy search - if getenv("GREEDY"): + # try a beam search + if getenv("BEAM"): lin = Linearizer(si.ast, device.linearizer_opts) if str(lin.ast) in global_db: for ao in global_db[str(lin.ast)]: lin.apply_opt(ao) else: + best_tm = float('inf') + beam = [lin] while 1: - acted_lins = get_linearizer_actions(lin) - timed_lins = {k:time_linearizer(v, rawbufs) for k,v in acted_lins.items()} - opts = sorted(timed_lins.items(), key=lambda x: x[1]) - if opts[0][0] == 0: break # we are done - lin = acted_lins[opts[0][0]] - if DEBUG >= 1: print(f"{opts[0][1]*1e3:10.2f} ms from {len(opts):3d} actions", lin.colored_shape()) + acted_lins = flatten([get_linearizer_actions(lin).items() for lin in beam]) + timed_lins = [(v,time_linearizer(v, rawbufs)) for k,v in acted_lins if k != 0] + opts = sorted(timed_lins, key=lambda x: x[1]) + if len(opts) == 0 or best_tm <= opts[0][1]: break # we didn't get faster + best_tm = opts[0][1] + beam = [x[0] for x in opts[:getenv("BEAM")]] + if DEBUG >= 1: print(f"{opts[0][1]*1e3:10.2f} ms from {len(opts):3d} actions", beam[0].colored_shape()) + lin = beam[0] global_db[str(lin.ast)] = lin.applied_opts lins.append(lin) diff --git a/extra/optimization/get_action_space.py b/extra/optimization/get_action_space.py new file mode 100644 index 0000000000..ca3eaeb992 --- /dev/null +++ b/extra/optimization/get_action_space.py @@ -0,0 +1,15 @@ +from tqdm import tqdm +from extra.optimization.helpers import load_worlds, ast_str_to_lin +from tinygrad.codegen.search import actions + +if __name__ == "__main__": + ast_strs = load_worlds(False, False, False) + tactions = set() + for ast_str in tqdm(ast_strs): + lin = ast_str_to_lin(ast_str) + lin.hand_coded_optimizations() + for o in lin.applied_opts: + assert o in actions + tactions.add(o) + print(len(tactions)) + print(sorted(list(tactions))) diff --git a/extra/optimization/test_time_linearizer.py b/extra/optimization/test_time_linearizer.py new file mode 100644 index 0000000000..ab3d592034 --- /dev/null +++ b/extra/optimization/test_time_linearizer.py @@ -0,0 +1,21 @@ +from extra.optimization.helpers import load_worlds, ast_str_to_lin +from tinygrad.codegen.search import bufs_from_lin, time_linearizer, get_linearizer_actions + +if __name__ == "__main__": + ast_strs = load_worlds() + for i, ast_str in enumerate(ast_strs): + lin = ast_str_to_lin(ast_str) + rawbufs = bufs_from_lin(lin) + test_tm = time_linearizer(lin, rawbufs) + if test_tm < 1e-2: continue + print(f"EXAMPLE {i}") + acted_lins = get_linearizer_actions(lin) + ok_avg, short_avg = 0, 0 + for k,v in acted_lins.items(): + tm1 = time_linearizer(v, rawbufs) + tm2 = time_linearizer(v, rawbufs) + tm3 = time_linearizer(v, rawbufs, False) + print(v.colored_shape(50), f"{tm1*1e3:10.2f} {tm2*1e3:10.2f} {tm3*1e3:10.2f} : {((tm1-tm2)/tm1)*100:5.2f}% vs {((tm1-tm3)/tm1)*100:5.2f}%") + ok_avg += (tm1-tm2)/tm1 + short_avg += (tm1-tm3)/tm1 + print(f"{ok_avg/len(acted_lins)*100:5.2f}% vs {short_avg/len(acted_lins)*100:5.2f}%") diff --git a/tinygrad/codegen/optimizer.py b/tinygrad/codegen/optimizer.py index ee7d7c83ab..7afe173a84 100644 --- a/tinygrad/codegen/optimizer.py +++ b/tinygrad/codegen/optimizer.py @@ -10,7 +10,7 @@ from tinygrad.shape.view import View, strides_for_shape from enum import Enum, auto class OptOps(Enum): - UPCAST = auto(); LOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702 + UPCAST = auto(); UNROLL = auto(); LOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702 def __lt__(self, x:OptOps): return self.value < x.value @dataclass(frozen=True, order=True) @@ -322,20 +322,25 @@ class OptimizedKernel(Kernel): def apply_opt(self, opt:Opt): self.applied_opts.append(opt) - assert self.full_shape[opt.axis] % opt.amt == 0, "no longer valid shift" + axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else 0) + assert self.full_shape[axis] % opt.amt == 0, "no longer valid shift" if opt.op == OptOps.LOCAL: # cyan - assert opt.axis < (self.first_reduce-self.local_dims), "can't local a local or reduce" - self.shift_to(opt.axis, opt.amt, insert_before=self.first_reduce) + assert axis < (self.first_reduce-self.local_dims), "can't local a local or reduce" + self.shift_to(axis, opt.amt, insert_before=self.first_reduce) self.local_dims += 1 elif opt.op == OptOps.GROUP: # green - self.shift_to(opt.axis, opt.amt, insert_before=self.first_reduce + len(self.group_for_reduce)) + self.shift_to(axis, opt.amt, insert_before=self.first_reduce + len(self.group_for_reduce)) self.group_for_reduce.append(opt.amt) elif opt.op == OptOps.GROUPTOP: # green - self.shift_to(opt.axis, opt.amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce)) + self.shift_to(axis, opt.amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce)) self.group_for_reduce.append(opt.amt) - elif opt.op == OptOps.UPCAST: # yellow (or purple if it's a reduce axis) - assert opt.axis < self.shape_len-self.upcasted, "can't upcasted already upcasted" - self.shift_to(opt.axis, opt.amt, insert_before=None if opt.axis < self.first_reduce else len(self.full_unupcasted_shape)) + elif opt.op == OptOps.UNROLL: # purple + assert axis < self.shape_len-self.upcasted, "can't upcasted already upcasted" + self.shift_to(axis, opt.amt, insert_before=len(self.full_unupcasted_shape)) + self.upcast() + elif opt.op == OptOps.UPCAST: # yellow + assert axis < self.first_reduce, "upcast is for non-reduce" + self.shift_to(axis, opt.amt, insert_before=None) self.upcast() self.simplify_ones() @@ -345,7 +350,8 @@ class OptimizedKernel(Kernel): if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType: assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}" if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: - self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4)) + if unit_stride_axes_mul_4[0] < self.first_reduce: self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4)) + else: self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4)) def hand_coded_optimizations(self): # if there's images in the earlybufs, we have to make an axis the 4 loading one @@ -432,14 +438,14 @@ class OptimizedKernel(Kernel): # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))): if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis - self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, s)) + self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, s)) # if it's small, upcast a second reduce dimension too if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int): - self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, s2)) + self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, s2)) else: for splits in [4]: if self.full_unupcasted_shape[-1]%splits == 0: - self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits)) + self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits)) break # if nothing at all is upcasted and it's easy to, do an upcast diff --git a/tinygrad/codegen/search.py b/tinygrad/codegen/search.py index 9ba27b68c1..f6d052d61d 100644 --- a/tinygrad/codegen/search.py +++ b/tinygrad/codegen/search.py @@ -9,20 +9,22 @@ from collections import defaultdict from tinygrad.codegen.optimizer import Opt, OptOps actions = [ - Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=3), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=5), Opt(op=OptOps.UPCAST, axis=0, amt=6), Opt(op=OptOps.UPCAST, axis=0, amt=7), Opt(op=OptOps.UPCAST, axis=0, amt=9), Opt(op=OptOps.UPCAST, axis=0, amt=10), Opt(op=OptOps.UPCAST, axis=0, amt=12), Opt(op=OptOps.UPCAST, axis=0, amt=24), - Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=5), Opt(op=OptOps.UPCAST, axis=1, amt=6), Opt(op=OptOps.UPCAST, axis=1, amt=7), Opt(op=OptOps.UPCAST, axis=1, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=9), Opt(op=OptOps.UPCAST, axis=1, amt=10), Opt(op=OptOps.UPCAST, axis=1, amt=11), Opt(op=OptOps.UPCAST, axis=1, amt=12), Opt(op=OptOps.UPCAST, axis=1, amt=13), Opt(op=OptOps.UPCAST, axis=1, amt=14), Opt(op=OptOps.UPCAST, axis=1, amt=15), Opt(op=OptOps.UPCAST, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=17), Opt(op=OptOps.UPCAST, axis=1, amt=20), Opt(op=OptOps.UPCAST, axis=1, amt=21), Opt(op=OptOps.UPCAST, axis=1, amt=24), Opt(op=OptOps.UPCAST, axis=1, amt=25), Opt(op=OptOps.UPCAST, axis=1, amt=28), Opt(op=OptOps.UPCAST, axis=1, amt=32), - Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=5), Opt(op=OptOps.UPCAST, axis=2, amt=6), Opt(op=OptOps.UPCAST, axis=2, amt=7), Opt(op=OptOps.UPCAST, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=2, amt=9), Opt(op=OptOps.UPCAST, axis=2, amt=10), Opt(op=OptOps.UPCAST, axis=2, amt=11), Opt(op=OptOps.UPCAST, axis=2, amt=12), Opt(op=OptOps.UPCAST, axis=2, amt=14), Opt(op=OptOps.UPCAST, axis=2, amt=15), Opt(op=OptOps.UPCAST, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=2, amt=17), Opt(op=OptOps.UPCAST, axis=2, amt=20), Opt(op=OptOps.UPCAST, axis=2, amt=21), Opt(op=OptOps.UPCAST, axis=2, amt=24), Opt(op=OptOps.UPCAST, axis=2, amt=25), Opt(op=OptOps.UPCAST, axis=2, amt=27), Opt(op=OptOps.UPCAST, axis=2, amt=28), Opt(op=OptOps.UPCAST, axis=2, amt=30), Opt(op=OptOps.UPCAST, axis=2, amt=31), Opt(op=OptOps.UPCAST, axis=2, amt=32), - Opt(op=OptOps.UPCAST, axis=3, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=5), Opt(op=OptOps.UPCAST, axis=3, amt=6), Opt(op=OptOps.UPCAST, axis=3, amt=7), Opt(op=OptOps.UPCAST, axis=3, amt=8), Opt(op=OptOps.UPCAST, axis=3, amt=9), Opt(op=OptOps.UPCAST, axis=3, amt=10), Opt(op=OptOps.UPCAST, axis=3, amt=11), Opt(op=OptOps.UPCAST, axis=3, amt=12), Opt(op=OptOps.UPCAST, axis=3, amt=13), Opt(op=OptOps.UPCAST, axis=3, amt=14), Opt(op=OptOps.UPCAST, axis=3, amt=15), Opt(op=OptOps.UPCAST, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=3, amt=24), Opt(op=OptOps.UPCAST, axis=3, amt=26), Opt(op=OptOps.UPCAST, axis=3, amt=27), Opt(op=OptOps.UPCAST, axis=3, amt=28), Opt(op=OptOps.UPCAST, axis=3, amt=32), - Opt(op=OptOps.UPCAST, axis=4, amt=2), Opt(op=OptOps.UPCAST, axis=4, amt=3), Opt(op=OptOps.UPCAST, axis=4, amt=4), Opt(op=OptOps.UPCAST, axis=4, amt=5), Opt(op=OptOps.UPCAST, axis=4, amt=6), Opt(op=OptOps.UPCAST, axis=4, amt=7), Opt(op=OptOps.UPCAST, axis=4, amt=8), Opt(op=OptOps.UPCAST, axis=4, amt=9), Opt(op=OptOps.UPCAST, axis=4, amt=10), Opt(op=OptOps.UPCAST, axis=4, amt=11), Opt(op=OptOps.UPCAST, axis=4, amt=12), Opt(op=OptOps.UPCAST, axis=4, amt=13), Opt(op=OptOps.UPCAST, axis=4, amt=14), Opt(op=OptOps.UPCAST, axis=4, amt=15), Opt(op=OptOps.UPCAST, axis=4, amt=16), Opt(op=OptOps.UPCAST, axis=4, amt=20), Opt(op=OptOps.UPCAST, axis=4, amt=24), Opt(op=OptOps.UPCAST, axis=4, amt=26), Opt(op=OptOps.UPCAST, axis=4, amt=27), Opt(op=OptOps.UPCAST, axis=4, amt=28), Opt(op=OptOps.UPCAST, axis=4, amt=30), Opt(op=OptOps.UPCAST, axis=4, amt=32), - Opt(op=OptOps.UPCAST, axis=5, amt=2), Opt(op=OptOps.UPCAST, axis=5, amt=3), Opt(op=OptOps.UPCAST, axis=5, amt=4), Opt(op=OptOps.UPCAST, axis=5, amt=5), Opt(op=OptOps.UPCAST, axis=5, amt=6), Opt(op=OptOps.UPCAST, axis=5, amt=7), Opt(op=OptOps.UPCAST, axis=5, amt=8), Opt(op=OptOps.UPCAST, axis=5, amt=9), Opt(op=OptOps.UPCAST, axis=5, amt=11), Opt(op=OptOps.UPCAST, axis=5, amt=13), Opt(op=OptOps.UPCAST, axis=5, amt=16), Opt(op=OptOps.UPCAST, axis=5, amt=24), Opt(op=OptOps.UPCAST, axis=5, amt=26), Opt(op=OptOps.UPCAST, axis=5, amt=27), - Opt(op=OptOps.UPCAST, axis=6, amt=2), Opt(op=OptOps.UPCAST, axis=6, amt=3), Opt(op=OptOps.UPCAST, axis=6, amt=4), Opt(op=OptOps.UPCAST, axis=6, amt=5), Opt(op=OptOps.UPCAST, axis=6, amt=6), Opt(op=OptOps.UPCAST, axis=6, amt=7), Opt(op=OptOps.UPCAST, axis=6, amt=13), Opt(op=OptOps.UPCAST, axis=6, amt=16), Opt(op=OptOps.UPCAST, axis=6, amt=24), Opt(op=OptOps.UPCAST, axis=6, amt=26), Opt(op=OptOps.UPCAST, axis=6, amt=27), Opt(op=OptOps.UPCAST, axis=6, amt=31), - Opt(op=OptOps.UPCAST, axis=7, amt=2), Opt(op=OptOps.UPCAST, axis=7, amt=3), Opt(op=OptOps.UPCAST, axis=7, amt=4), Opt(op=OptOps.UPCAST, axis=7, amt=7), + Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=3), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=5), Opt(op=OptOps.UPCAST, axis=0, amt=6), Opt(op=OptOps.UPCAST, axis=0, amt=7), + Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=5), Opt(op=OptOps.UPCAST, axis=1, amt=6), Opt(op=OptOps.UPCAST, axis=1, amt=7), + Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=5), Opt(op=OptOps.UPCAST, axis=2, amt=6), Opt(op=OptOps.UPCAST, axis=2, amt=7), + Opt(op=OptOps.UPCAST, axis=3, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=5), Opt(op=OptOps.UPCAST, axis=3, amt=6), Opt(op=OptOps.UPCAST, axis=3, amt=7), + Opt(op=OptOps.UPCAST, axis=4, amt=2), Opt(op=OptOps.UPCAST, axis=4, amt=3), Opt(op=OptOps.UPCAST, axis=4, amt=4), Opt(op=OptOps.UPCAST, axis=4, amt=5), + Opt(op=OptOps.UPCAST, axis=5, amt=3), Opt(op=OptOps.UPCAST, axis=5, amt=4), + Opt(op=OptOps.UNROLL, axis=0, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=3), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=5), Opt(op=OptOps.UNROLL, axis=0, amt=6), Opt(op=OptOps.UNROLL, axis=0, amt=7), Opt(op=OptOps.UNROLL, axis=0, amt=8), Opt(op=OptOps.UNROLL, axis=0, amt=9), Opt(op=OptOps.UNROLL, axis=0, amt=10), Opt(op=OptOps.UNROLL, axis=0, amt=11), Opt(op=OptOps.UNROLL, axis=0, amt=12), Opt(op=OptOps.UNROLL, axis=0, amt=13), Opt(op=OptOps.UNROLL, axis=0, amt=14), Opt(op=OptOps.UNROLL, axis=0, amt=15), Opt(op=OptOps.UNROLL, axis=0, amt=16), Opt(op=OptOps.UNROLL, axis=0, amt=17), Opt(op=OptOps.UNROLL, axis=0, amt=20), Opt(op=OptOps.UNROLL, axis=0, amt=21), Opt(op=OptOps.UNROLL, axis=0, amt=24), Opt(op=OptOps.UNROLL, axis=0, amt=25), Opt(op=OptOps.UNROLL, axis=0, amt=27), Opt(op=OptOps.UNROLL, axis=0, amt=28), Opt(op=OptOps.UNROLL, axis=0, amt=30), Opt(op=OptOps.UNROLL, axis=0, amt=31), Opt(op=OptOps.UNROLL, axis=0, amt=32), + Opt(op=OptOps.UNROLL, axis=1, amt=2), Opt(op=OptOps.UNROLL, axis=1, amt=3), Opt(op=OptOps.UNROLL, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=1, amt=5), Opt(op=OptOps.UNROLL, axis=1, amt=6), Opt(op=OptOps.UNROLL, axis=1, amt=7), Opt(op=OptOps.UNROLL, axis=1, amt=8), Opt(op=OptOps.UNROLL, axis=1, amt=9), Opt(op=OptOps.UNROLL, axis=1, amt=10), Opt(op=OptOps.UNROLL, axis=1, amt=11), Opt(op=OptOps.UNROLL, axis=1, amt=16), Opt(op=OptOps.UNROLL, axis=1, amt=24), Opt(op=OptOps.UNROLL, axis=1, amt=26), Opt(op=OptOps.UNROLL, axis=1, amt=27), Opt(op=OptOps.UNROLL, axis=1, amt=28), Opt(op=OptOps.UNROLL, axis=1, amt=32), + Opt(op=OptOps.UNROLL, axis=2, amt=2), Opt(op=OptOps.UNROLL, axis=2, amt=3), Opt(op=OptOps.UNROLL, axis=2, amt=4), Opt(op=OptOps.UNROLL, axis=2, amt=5), Opt(op=OptOps.UNROLL, axis=2, amt=6), Opt(op=OptOps.UNROLL, axis=2, amt=7), Opt(op=OptOps.UNROLL, axis=2, amt=8), Opt(op=OptOps.UNROLL, axis=2, amt=9), Opt(op=OptOps.UNROLL, axis=2, amt=11), Opt(op=OptOps.UNROLL, axis=2, amt=12), Opt(op=OptOps.UNROLL, axis=2, amt=13), Opt(op=OptOps.UNROLL, axis=2, amt=14), Opt(op=OptOps.UNROLL, axis=2, amt=16), Opt(op=OptOps.UNROLL, axis=2, amt=24), Opt(op=OptOps.UNROLL, axis=2, amt=26), Opt(op=OptOps.UNROLL, axis=2, amt=27), Opt(op=OptOps.UNROLL, axis=2, amt=28), Opt(op=OptOps.UNROLL, axis=2, amt=30), Opt(op=OptOps.UNROLL, axis=2, amt=31), Opt(op=OptOps.UNROLL, axis=2, amt=32), + Opt(op=OptOps.UNROLL, axis=3, amt=2), Opt(op=OptOps.UNROLL, axis=3, amt=3), Opt(op=OptOps.UNROLL, axis=3, amt=11), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=3), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=3), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=1, amt=8), Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.LOCAL, axis=3, amt=2), Opt(op=OptOps.LOCAL, axis=3, amt=3), Opt(op=OptOps.LOCAL, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=8), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=4, amt=2), Opt(op=OptOps.LOCAL, axis=4, amt=3), Opt(op=OptOps.LOCAL, axis=4, amt=16), - Opt(op=OptOps.GROUP, axis=1, amt=4), Opt(op=OptOps.GROUP, axis=1, amt=8), + Opt(op=OptOps.GROUP, axis=1, amt=4), Opt(op=OptOps.GROUP, axis=1, amt=8), Opt(op=OptOps.GROUP, axis=2, amt=8), Opt(op=OptOps.GROUPTOP, axis=0, amt=16), Opt(op=OptOps.GROUPTOP, axis=0, amt=256), Opt(op=OptOps.GROUPTOP, axis=1, amt=16), Opt(op=OptOps.GROUPTOP, axis=1, amt=256), Opt(op=OptOps.GROUPTOP, axis=2, amt=16), Opt(op=OptOps.GROUPTOP, axis=2, amt=256)] @@ -30,7 +32,7 @@ device:Compiled = cast(Compiled, Device[Device.DEFAULT]) # returns time in seconds logtm = open(getenv("LOGTM", ""),"a") if getenv("LOGTM", "") else None -def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, cnt=3, should_copy=True) -> float: +def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True) -> float: if should_copy: lin = deepcopy(lin) # TODO: remove the need for this var_vals = {k:k.min for k in vars_from_ast(lin.ast)} try: @@ -39,21 +41,22 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru real_global_size = prg.global_size[:] if allow_test_size: test_global_size = prg.global_size[:] - while prod(test_global_size) > 16384: + while prod(test_global_size) > max_global_size: for j in range(2,-1,-1): - if test_global_size[j] > 1: + if test_global_size[j] > 16: test_global_size[j] //= 2 break factor = prod(prg.global_size) / prod(test_global_size) prg.global_size = test_global_size + #print(real_global_size, test_global_size, factor) else: factor = 1 tms = [prg(rawbufs, var_vals, force_wait=True)*factor for _ in range(cnt)] prg.global_size = real_global_size except Exception: - print("FAILED") - print(lin.ast) - print(lin.applied_opts) + #print("FAILED") + #print(lin.ast) + #print(lin.applied_opts) tms = [float('inf')] if logtm: logtm.write(str((lin.ast, lin.applied_opts, tms))+"\n") return min(tms)