mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
improve search more
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
import os, random, traceback
|
||||
import time
|
||||
import itertools
|
||||
from enum import Enum
|
||||
import numpy as np
|
||||
@@ -13,10 +14,12 @@ Interventions = Enum("Interventions", ["SWAP", "UPCAST"])
|
||||
def get_interventions(k):
|
||||
p1 = [(Interventions.SWAP, x) for x in itertools.combinations(range(k.first_reduce), 2)]
|
||||
p2 = [(Interventions.SWAP, x) for x in itertools.combinations(range(k.first_reduce, k.shape_len), 2)]
|
||||
p3 = []
|
||||
p3 = [(Interventions.UPCAST, None)] if max(st.shape[-1] for st in k.sts) < 32 else []
|
||||
for up_axis in range(k.shape_len):
|
||||
for amount in [2,4,8]:
|
||||
if all(st.shape[up_axis] == 1 for st in k.sts): continue
|
||||
max_up = max(st.shape[up_axis] for st in k.sts)
|
||||
if max_up == 1: continue
|
||||
for amount in list(set([2,4,8,max_up])):
|
||||
if amount >= 32: continue
|
||||
if not all(st.shape[up_axis] == 1 or st.shape[up_axis]%amount == 0 for st in k.sts): continue
|
||||
p3.append((Interventions.UPCAST, (up_axis, amount)))
|
||||
return p1+p2+p3
|
||||
@@ -29,26 +32,31 @@ def apply_intervention(k, typ, dat):
|
||||
new_order[a1], new_order[a2] = new_order[a2], new_order[a1]
|
||||
k.reshape_and_permute(None, new_order)
|
||||
elif typ == Interventions.UPCAST:
|
||||
# upcast
|
||||
up_axis, amount = dat[0], dat[1]
|
||||
# no change, we added a dimension
|
||||
k.reshape_and_permute(
|
||||
lambda x: list(x[0:up_axis]) + ([x[up_axis]//amount, amount] if x[up_axis] > 1 else [1,1]) + list(x[up_axis+1:]),
|
||||
[i for i in range(k.shape_len+1) if i != up_axis+1] + [up_axis+1])
|
||||
if dat is not None:
|
||||
# upcast
|
||||
up_axis, amount = dat[0], dat[1]
|
||||
# no change, we added a dimension
|
||||
k.reshape_and_permute(
|
||||
lambda x: list(x[0:up_axis]) + ([x[up_axis]//amount, amount] if x[up_axis] > 1 else [1,1]) + list(x[up_axis+1:]),
|
||||
[i for i in range(k.shape_len+1) if i != up_axis+1] + [up_axis+1])
|
||||
# drop the last dimension
|
||||
k.upcast()
|
||||
k.simplify_ones()
|
||||
|
||||
def run_and_time(k):
|
||||
try:
|
||||
prog = k.codegen()
|
||||
ret = []
|
||||
for i in range(3):
|
||||
e = prog(*k.bufs)
|
||||
CL.cl_queue.finish()
|
||||
ret.append((e.profile.end - e.profile.start) * OSX_TIMING_RATIO)
|
||||
return min(ret)
|
||||
except Exception:
|
||||
return float('inf')
|
||||
prog = k.codegen()
|
||||
ret = []
|
||||
for i in range(3):
|
||||
t1 = time.monotonic_ns()
|
||||
e = prog(*k.bufs)
|
||||
e.wait()
|
||||
t4 = time.monotonic_ns()
|
||||
t2, t3 = e.profile.start * OSX_TIMING_RATIO, e.profile.end * OSX_TIMING_RATIO
|
||||
#print(*[f"{(x-t1)*1e-3:7.2f} us" for x in [t1, t2, t3, t4]]) # TODO: this may be wrong on non OS X
|
||||
#assert t1 < t2 < t3 < t4, "timings not in order"
|
||||
ret.append(t3-t2)
|
||||
#ret.append(t4-t1)
|
||||
return min(ret)
|
||||
|
||||
def search_one(ast, winning_interventions):
|
||||
k = CLASTKernel(ast)
|
||||
@@ -57,11 +65,14 @@ def search_one(ast, winning_interventions):
|
||||
options = [(run_and_time(k), None, 0.9)]
|
||||
print(f"{options[-1][1]} : {options[-1][0]*1e-3:.2f}")
|
||||
for int in ints:
|
||||
k = CLASTKernel(ast)
|
||||
for w in winning_interventions: apply_intervention(k, *w)
|
||||
apply_intervention(k, *int)
|
||||
options.append((run_and_time(k), int, 1.0))
|
||||
print(f"{options[-1][1]} : {options[-1][0]*1e-3:.2f}")
|
||||
try:
|
||||
k = CLASTKernel(ast)
|
||||
for w in winning_interventions: apply_intervention(k, *w)
|
||||
apply_intervention(k, *int)
|
||||
options.append((run_and_time(k), int, 1.0))
|
||||
print(f"{options[-1][1]} : {options[-1][0]*1e-3:.2f}")
|
||||
except Exception:
|
||||
print(int, "FAILED")
|
||||
options = sorted(options, key=lambda x: x[0]*x[2])
|
||||
return options[0]
|
||||
|
||||
@@ -84,8 +95,9 @@ def search(ast):
|
||||
k = CLASTKernel(ast)
|
||||
for w in winning_interventions: apply_intervention(k, *w)
|
||||
k.codegen()(*k.bufs)
|
||||
#k.print()
|
||||
test_ast(k)
|
||||
print(f"improved from {baseline/1e6:.2f} ms to {best_time/1e6:.2f} ms, a {baseline/best_time:.2f}x speedup")
|
||||
print(f"improved from {baseline/1e6:.2f} ms to {best_time/1e6:.2f} ms, a {baseline/best_time:.2f}x speedup @ {k.info.flops/best_time:.2f} GFLOPS")
|
||||
|
||||
if __name__ == "__main__":
|
||||
if int(os.getenv("OP", "0")) == 1:
|
||||
|
||||
Reference in New Issue
Block a user