Files
tinygrad/extra/optimization/helpers.py
George Hotz 41bfeb2c1e start work on auto opt (#2034)
* start work on auto opt

* lin failure

* not beating hcopt

* greedy

* timing is fast

* codegen.search

* greedy search in handcode_opt

* track running gflops

* clean up those files

* no failure
2023-10-11 12:54:53 -07:00

59 lines
1.8 KiB
Python

# stuff needed to unpack a kernel
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.helpers import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable
# Serialized ASTs can contain float constants printed as bare `inf` / `nan`
# (repr(float('inf')) == 'inf'), so these names must exist for eval() to rebuild them.
inf = float('inf')
nan = float('nan')
# kernel unpacker
from tinygrad.codegen.linearizer import Linearizer
def ast_str_to_lin(ast_str):
  """Rebuild a Linearizer from the string form of a kernel AST.

  The string is evaluated against this module's globals, which is why the module imports
  the op/buffer/shape types and defines `inf`/`nan` ("stuff needed to unpack a kernel").

  NOTE(review): eval() executes arbitrary code; only pass AST strings from a trusted source.
  """
  kernel_ast = eval(ast_str)
  return Linearizer(kernel_ast)
# load worlds
import random
from tinygrad.helpers import dedup
def load_worlds(path="/tmp/sops"):
  """Load, filter and shuffle serialized kernel AST strings.

  Reads one AST string per line from `path` (default "/tmp/sops", as before) and removes
  duplicates via `dedup`. It keeps only entries that mention "ReduceOps" and contain neither
  "dtypes.image" nor "Variable", then shuffles them with a fixed seed.

  NOTE: this reseeds the global `random` module with 1337 before shuffling. The order is
  reproducible, and any later `random` calls made by the caller continue from that seed.

  Returns the filtered, shuffled list of AST strings.
  """
  # Use a context manager so the file handle is closed instead of leaked.
  with open(path) as f:
    ast_strs = dedup(f.read().strip().split("\n"))
  ast_strs = [x for x in ast_strs if "ReduceOps" in x and "dtypes.image" not in x and "Variable" not in x]
  random.seed(1337)
  random.shuffle(ast_strs)
  return ast_strs
def assert_same_lin(l1, l2):
  """Assert that two linearizers have the same colored shape and equal ShapeTrackers.

  Raises AssertionError in three cases:
    * `colored_shape()` differs;
    * the `.sts` lists have different lengths;
    * any pair of corresponding entries compares unequal.
  """
  assert l1.colored_shape() == l2.colored_shape()
  # zip() stops at the shorter list, so without this check a missing trailing st would go unnoticed.
  assert len(l1.sts) == len(l2.sts), f"sts length mismatch: {len(l1.sts)} != {len(l2.sts)}"
  assert all(x==y for x,y in zip(l1.sts, l2.sts))
# get features
import math
from tinygrad.shape.symbolic import Node
MAX_DIMS = 16
def lin_to_feats(lin):
  """Encode a linearizer's full shape and dimension colors as a fixed-length float vector.

  For each dimension `s` of `lin.full_shape`, 15 features are emitted:
    * 8 size features. For a concrete dimension:
      [True, log2(s), min(33, s), s%2==0, s%3==0, s%4==0, s%8==0, s%16==0].
      For a symbolic `Node` dimension: [False] followed by 7 zeros.
    * A 7-wide one-hot of the dimension's color from `lin.colors()`, which is assumed to
      align one-to-one with `full_shape`.
  The vector is zero-padded to MAX_DIMS dimensions and every entry is converted to float.

  Returns a list of 15*MAX_DIMS floats (240 for MAX_DIMS=16).

  Raises:
    * ValueError for an unknown color name (list.index);
    * ValueError for a non-positive concrete dimension (math.log2);
    * AssertionError if the final length is wrong, e.g. more than MAX_DIMS dimensions.
  """
  all_colors = ["blue", "cyan", "white", "green", "red", "magenta", "yellow"]
  n_size_feats = 8  # concrete flag + log2 + clamped size + 5 divisibility flags
  feats_per_dim = n_size_feats + len(all_colors)  # 15
  lc = [all_colors.index(x) for x in lin.colors()]
  ret = []
  for s,c in zip(lin.full_shape, lc):
    if isinstance(s, Node):
      # Symbolic dimension: mark it non-concrete and zero the numeric size features.
      ret += [False] + [0]*(n_size_feats-1)
    else:
      ret += [True, math.log2(s), min(33, s)] + [s%d == 0 for d in (2, 3, 4, 8, 16)]
    color_onehot = [0]*len(all_colors)
    color_onehot[c] = 1
    ret += color_onehot
  # Pad unused dimensions with zeros. This is a no-op (negative repeat) when len(full_shape) > MAX_DIMS,
  # in which case the length assertion below fires.
  ret += [0] * (feats_per_dim*(MAX_DIMS-len(lin.full_shape)))
  ret = [float(x) for x in ret]
  assert len(ret) == feats_per_dim*MAX_DIMS, f"wrong len {len(ret)}"
  return ret