improve search more

This commit is contained in:
George Hotz
2023-01-29 02:08:57 -08:00
parent f6bbd43cb8
commit 87879cf4b6

View File

@@ -1,5 +1,6 @@
#!/usr/bin/env python
import os, random, traceback
import time
import itertools
from enum import Enum
import numpy as np
@@ -13,10 +14,12 @@ Interventions = Enum("Interventions", ["SWAP", "UPCAST"])
def get_interventions(k):
    """Enumerate the candidate interventions that can be tried on kernel ``k``.

    Returns a list of ``(Interventions member, data)`` tuples, in this order:

    * ``SWAP`` of every pair of axes drawn from ``range(k.first_reduce)``.
    * ``SWAP`` of every pair of axes drawn from
      ``range(k.first_reduce, k.shape_len)``.
    * ``UPCAST`` with ``None`` as data, offered only when the largest
      last-axis size across ``k.sts`` is below 32.
    * ``UPCAST`` with ``(up_axis, amount)`` for every axis whose largest size
      is not 1 and every amount in ``{2, 4, 8, max_up}`` that is below 32 and
      evenly divides each non-1 size on that axis.

    ``k`` is only read here: ``first_reduce``, ``shape_len`` and ``sts``
    (each entry exposing a ``shape`` sequence).
    """
    p1 = [(Interventions.SWAP, x) for x in itertools.combinations(range(k.first_reduce), 2)]
    p2 = [(Interventions.SWAP, x) for x in itertools.combinations(range(k.first_reduce, k.shape_len), 2)]
    p3 = [(Interventions.UPCAST, None)] if max(st.shape[-1] for st in k.sts) < 32 else []
    for up_axis in range(k.shape_len):
        # Sizes of this axis in every shape tracker; computed once per axis.
        axis_sizes = [st.shape[up_axis] for st in k.sts]
        max_up = max(axis_sizes)
        if max_up == 1:
            # Nothing to split on an axis that is 1 everywhere.
            continue
        # The set removes the duplicate when max_up is itself 2, 4 or 8;
        # sorting makes the candidate order explicit instead of relying on
        # set iteration order.
        for amount in sorted({2, 4, 8, max_up}):
            if amount >= 32:
                continue
            # Every tracker must either be 1 on this axis or split evenly.
            if not all(size == 1 or size % amount == 0 for size in axis_sizes):
                continue
            p3.append((Interventions.UPCAST, (up_axis, amount)))
    return p1 + p2 + p3
@@ -29,26 +32,31 @@ def apply_intervention(k, typ, dat):
new_order[a1], new_order[a2] = new_order[a2], new_order[a1]
k.reshape_and_permute(None, new_order)
elif typ == Interventions.UPCAST:
# upcast
up_axis, amount = dat[0], dat[1]
# no change, we added a dimension
k.reshape_and_permute(
lambda x: list(x[0:up_axis]) + ([x[up_axis]//amount, amount] if x[up_axis] > 1 else [1,1]) + list(x[up_axis+1:]),
[i for i in range(k.shape_len+1) if i != up_axis+1] + [up_axis+1])
if dat is not None:
# upcast
up_axis, amount = dat[0], dat[1]
# no change, we added a dimension
k.reshape_and_permute(
lambda x: list(x[0:up_axis]) + ([x[up_axis]//amount, amount] if x[up_axis] > 1 else [1,1]) + list(x[up_axis+1:]),
[i for i in range(k.shape_len+1) if i != up_axis+1] + [up_axis+1])
# drop the last dimension
k.upcast()
k.simplify_ones()
def run_and_time(k):
    """Generate the program for kernel ``k``, run it three times, return the best time.

    Each run calls the program returned by ``k.codegen()`` with ``k.bufs``,
    waits on the returned event, and measures
    ``(e.profile.end - e.profile.start)`` after scaling both stamps by
    ``OSX_TIMING_RATIO``. The minimum of the three measurements is returned.

    Exceptions raised by code generation or by a launch are not caught here;
    they propagate to the caller.
    """
    prog = k.codegen()
    ret = []
    for _ in range(3):
        e = prog(*k.bufs)
        # Block until the launch has finished before reading its profile.
        e.wait()
        t2, t3 = e.profile.start * OSX_TIMING_RATIO, e.profile.end * OSX_TIMING_RATIO
        # NOTE(review): the original author flagged that comparing these
        # scaled stamps against host time.monotonic_ns() "may be wrong on
        # non OS X" -- confirm before relying on absolute values.
        ret.append(t3 - t2)
    return min(ret)
def search_one(ast, winning_interventions):
k = CLASTKernel(ast)
@@ -57,11 +65,14 @@ def search_one(ast, winning_interventions):
options = [(run_and_time(k), None, 0.9)]
print(f"{options[-1][1]} : {options[-1][0]*1e-3:.2f}")
for int in ints:
k = CLASTKernel(ast)
for w in winning_interventions: apply_intervention(k, *w)
apply_intervention(k, *int)
options.append((run_and_time(k), int, 1.0))
print(f"{options[-1][1]} : {options[-1][0]*1e-3:.2f}")
try:
k = CLASTKernel(ast)
for w in winning_interventions: apply_intervention(k, *w)
apply_intervention(k, *int)
options.append((run_and_time(k), int, 1.0))
print(f"{options[-1][1]} : {options[-1][0]*1e-3:.2f}")
except Exception:
print(int, "FAILED")
options = sorted(options, key=lambda x: x[0]*x[2])
return options[0]
@@ -84,8 +95,9 @@ def search(ast):
k = CLASTKernel(ast)
for w in winning_interventions: apply_intervention(k, *w)
k.codegen()(*k.bufs)
#k.print()
test_ast(k)
print(f"improved from {baseline/1e6:.2f} ms to {best_time/1e6:.2f} ms, a {baseline/best_time:.2f}x speedup")
print(f"improved from {baseline/1e6:.2f} ms to {best_time/1e6:.2f} ms, a {baseline/best_time:.2f}x speedup @ {k.info.flops/best_time:.2f} GFLOPS")
if __name__ == "__main__":
if int(os.getenv("OP", "0")) == 1: