mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
search with SHIFT, REDUCE
This commit is contained in:
@@ -10,7 +10,7 @@ from tinygrad.llops.ops_gpu import GPUBuffer, CLASTKernel, CL
|
||||
from tinygrad.runtime.opencl import OSX_TIMING_RATIO
|
||||
from test.lib_test_ast import test_ast
|
||||
|
||||
Interventions = Enum("Interventions", ["SWAP", "UPCAST"])
|
||||
Interventions = Enum("Interventions", ["SWAP", "UPCAST", "SHIFT", "REDUCE"])
|
||||
def get_interventions(k):
|
||||
p1 = [(Interventions.SWAP, x) for x in itertools.combinations(range(k.first_reduce), 2)]
|
||||
p2 = [(Interventions.SWAP, x) for x in itertools.combinations(range(k.first_reduce, k.shape_len), 2)]
|
||||
@@ -18,11 +18,19 @@ def get_interventions(k):
|
||||
for up_axis in range(k.shape_len):
|
||||
max_up = max(st.shape[up_axis] for st in k.sts)
|
||||
if max_up == 1: continue
|
||||
for amount in list(set([2,4,8,max_up])):
|
||||
for amount in sorted(list(set([2,4,8,max_up]))):
|
||||
if amount >= 32: continue
|
||||
if not all(st.shape[up_axis] == 1 or st.shape[up_axis]%amount == 0 for st in k.sts): continue
|
||||
p3.append((Interventions.UPCAST, (up_axis, amount)))
|
||||
return p1+p2+p3
|
||||
p4 = []
|
||||
for up_axis in range(1,k.first_reduce):
|
||||
for amount in [4,8,16,32]:
|
||||
if k.sts[0].shape[up_axis] % amount == 0:
|
||||
p4.append((Interventions.SHIFT, (up_axis, amount, True)))
|
||||
p4.append((Interventions.SHIFT, (up_axis, amount, False)))
|
||||
max_up = max(st.shape[k.first_reduce] for st in k.sts)
|
||||
p5 = [(Interventions.REDUCE, (max_up,))]
|
||||
return p1+p2+p3+p4+p5
|
||||
|
||||
def apply_intervention(k, typ, dat):
|
||||
if typ == Interventions.SWAP:
|
||||
@@ -41,12 +49,20 @@ def apply_intervention(k, typ, dat):
|
||||
[i for i in range(k.shape_len+1) if i != up_axis+1] + [up_axis+1])
|
||||
# drop the last dimension
|
||||
k.upcast()
|
||||
elif typ == Interventions.SHIFT:
|
||||
up_axis, amount, flip = dat[0], dat[1], dat[2]
|
||||
k.reshape_and_permute(
|
||||
lambda x: list(x[0:up_axis]) + (([amount, x[up_axis]//amount] if flip else [x[up_axis]//amount, amount]) if x[up_axis] > 1 else [1,1]) + list(x[up_axis+1:]),
|
||||
[up_axis] + [i for i in range(k.shape_len+1) if i != up_axis])
|
||||
elif typ == Interventions.REDUCE:
|
||||
k.group_for_reduce.append(dat[0])
|
||||
k.simplify_ones()
|
||||
k.simplify_merge_adjacent()
|
||||
|
||||
def run_and_time(k):
|
||||
prog = k.codegen()
|
||||
ret = []
|
||||
for i in range(3):
|
||||
for i in range(5):
|
||||
t1 = time.monotonic_ns()
|
||||
e = prog(*k.bufs)
|
||||
e.wait()
|
||||
|
||||
Reference in New Issue
Block a user