lowerer is kernel [run_process_replay] (#5437)

This commit is contained in:
George Hotz
2024-07-12 18:50:55 -07:00
committed by GitHub
parent b8342fb085
commit 03c2dc8bd7
33 changed files with 215 additions and 213 deletions

View File

@@ -7,7 +7,7 @@ from tinygrad.nn.optim import Adam
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, assert_same_lin
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import getenv
# stuff needed to unpack a kernel
@@ -38,7 +38,7 @@ def dataset_from_cache(fn):
for f in tqdm(cur.fetchall()):
Xs,As = [], []
try:
lin = Lowerer(eval(f[0]))
lin = Kernel(eval(f[0]))
opts = pickle.loads(f[-1])
for o in opts:
Xs.append(lin_to_feats(lin, use_sts=True))

View File

@@ -13,7 +13,7 @@ inf, nan = float('inf'), float('nan')
from tinygrad.codegen.kernel import Opt, OptOps
# more stuff
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import actions
from extra.optimization.helpers import lin_to_feats
from extra.optimization.pretrain_valuenet import ValueNet
@@ -48,7 +48,7 @@ def dataset_from_cache(fn):
new_tm = min(opts_to_outcome[(ast,k)])
if math.isinf(old_tm) or math.isinf(new_tm) or old_tm < 1e-9 or new_tm < 1e-9: continue
try:
lin = Lowerer(eval(ast))
lin = Kernel(eval(ast))
except Exception:
continue
for opt in k[:-1]: lin.apply_opt(opt)

View File

@@ -1,12 +1,12 @@
import random
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.engine.search import actions
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import tqdm
tactions = set()
def test_rebuild(lin):
linr = Lowerer(lin.ast)
linr = Kernel(lin.ast)
for o in lin.applied_opts:
assert o in actions, f"{o} is not in actions"
tactions.add(o)

View File

@@ -9,12 +9,12 @@ from tinygrad.shape.symbolic import Variable, NumNode
inf, nan = float('inf'), float('nan')
# kernel unpacker
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
def ast_str_to_ast(ast_str:str) -> Tuple[LazyOp,...]: return LazyOp(MetaOps.SINK, val) if isinstance(val:=eval(ast_str), tuple) else val
def ast_str_to_lin(ast_str:str, opts=None): return Lowerer(ast_str_to_ast(ast_str), opts=opts)
def ast_str_to_lin(ast_str:str, opts=None): return Kernel(ast_str_to_ast(ast_str), opts=opts)
def kern_str_to_lin(kern_str:str, opts=None):
(ast, applied_opts,) = eval(kern_str)
k = Lowerer(ast, opts=opts)
k = Kernel(ast, opts=opts)
for opt in applied_opts:
k.apply_opt(opt)
return k
@@ -44,7 +44,7 @@ from tinygrad.shape.symbolic import Node
MAX_DIMS = 16
MAX_BUFS = 9
def lin_to_feats(lin:Lowerer, use_sts=True):
def lin_to_feats(lin:Kernel, use_sts=True):
assert lin.shape_len < MAX_DIMS, "too many dims"
all_colors = ["blue", "cyan", "white", "green", "red", "magenta", "yellow"]

View File

@@ -1,4 +1,4 @@
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tqdm import tqdm, trange
import math
import random
@@ -45,7 +45,7 @@ if __name__ == "__main__":
X,Y = [], []
for i,x in enumerate(tqdm(dset)):
ast, opts, tms = eval(x)
lin = Lowerer(ast)
lin = Kernel(ast)
for o in opts: lin.apply_opt(o)
if lin.shape_len >= MAX_DIMS: continue
if min(tms) == float('inf'): continue

View File

@@ -1,9 +1,9 @@
from typing import List, Tuple
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import get_linearizer_actions, actions
_net = None
def beam_q_estimate(beam:List[Tuple[Lowerer, float]]) -> List[Tuple[Lowerer, float]]:
def beam_q_estimate(beam:List[Tuple[Kernel, float]]) -> List[Tuple[Kernel, float]]:
global _net
if _net is None:
from tinygrad.nn.state import load_state_dict, safe_load

View File

@@ -4,7 +4,7 @@ from extra.optimization.helpers import ast_str_to_lin
from tinygrad import dtypes
from tinygrad.helpers import BEAM, getenv
from tinygrad.device import Device, Compiled
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin