search with SHIFT, REDUCE

This commit is contained in:
George Hotz
2023-01-29 02:42:20 -08:00
parent 87879cf4b6
commit 45c0aa6e2d

View File

@@ -10,7 +10,7 @@ from tinygrad.llops.ops_gpu import GPUBuffer, CLASTKernel, CL
from tinygrad.runtime.opencl import OSX_TIMING_RATIO
from test.lib_test_ast import test_ast
Interventions = Enum("Interventions", ["SWAP", "UPCAST"])
Interventions = Enum("Interventions", ["SWAP", "UPCAST", "SHIFT", "REDUCE"])
def get_interventions(k):
p1 = [(Interventions.SWAP, x) for x in itertools.combinations(range(k.first_reduce), 2)]
p2 = [(Interventions.SWAP, x) for x in itertools.combinations(range(k.first_reduce, k.shape_len), 2)]
@@ -18,11 +18,19 @@ def get_interventions(k):
for up_axis in range(k.shape_len):
max_up = max(st.shape[up_axis] for st in k.sts)
if max_up == 1: continue
for amount in list(set([2,4,8,max_up])):
for amount in sorted(list(set([2,4,8,max_up]))):
if amount >= 32: continue
if not all(st.shape[up_axis] == 1 or st.shape[up_axis]%amount == 0 for st in k.sts): continue
p3.append((Interventions.UPCAST, (up_axis, amount)))
return p1+p2+p3
p4 = []
for up_axis in range(1,k.first_reduce):
for amount in [4,8,16,32]:
if k.sts[0].shape[up_axis] % amount == 0:
p4.append((Interventions.SHIFT, (up_axis, amount, True)))
p4.append((Interventions.SHIFT, (up_axis, amount, False)))
max_up = max(st.shape[k.first_reduce] for st in k.sts)
p5 = [(Interventions.REDUCE, (max_up,))]
return p1+p2+p3+p4+p5
def apply_intervention(k, typ, dat):
if typ == Interventions.SWAP:
@@ -41,12 +49,20 @@ def apply_intervention(k, typ, dat):
[i for i in range(k.shape_len+1) if i != up_axis+1] + [up_axis+1])
# drop the last dimension
k.upcast()
elif typ == Interventions.SHIFT:
up_axis, amount, flip = dat[0], dat[1], dat[2]
k.reshape_and_permute(
lambda x: list(x[0:up_axis]) + (([amount, x[up_axis]//amount] if flip else [x[up_axis]//amount, amount]) if x[up_axis] > 1 else [1,1]) + list(x[up_axis+1:]),
[up_axis] + [i for i in range(k.shape_len+1) if i != up_axis])
elif typ == Interventions.REDUCE:
k.group_for_reduce.append(dat[0])
k.simplify_ones()
k.simplify_merge_adjacent()
def run_and_time(k):
prog = k.codegen()
ret = []
for i in range(3):
for i in range(5):
t1 = time.monotonic_ns()
e = prog(*k.bufs)
e.wait()