move kernel to opt (#10899)

This commit is contained in:
George Hotz
2025-06-20 15:22:28 -07:00
committed by GitHub
parent bb0299b9e5
commit 92678e59ee
60 changed files with 106 additions and 106 deletions

View File

@@ -5,9 +5,9 @@ from tinygrad.nn import Linear
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import Adam
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import actions
from tinygrad.opt.search import actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, assert_same_lin
from tinygrad.codegen.kernel import Kernel
from tinygrad.opt.kernel import Kernel
from tinygrad.helpers import getenv
# stuff needed to unpack a kernel
@@ -17,7 +17,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.uop.ops import Variable
inf, nan = float('inf'), float('nan')
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.opt.kernel import Opt, OptOps
INNER = 256
class PolicyNet:

View File

@@ -10,11 +10,11 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.uop.ops import Variable
inf, nan = float('inf'), float('nan')
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.opt.kernel import Opt, OptOps
# more stuff
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import actions
from tinygrad.opt.kernel import Kernel
from tinygrad.opt.search import actions
from extra.optimization.helpers import lin_to_feats
from extra.optimization.pretrain_valuenet import ValueNet
from tinygrad.nn.optim import Adam

View File

@@ -1,8 +1,8 @@
import random
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.engine.search import actions
from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.heuristic import hand_coded_optimizations
from tinygrad.opt.search import actions
from tinygrad.opt.kernel import Kernel
from tinygrad.opt.heuristic import hand_coded_optimizations
from tinygrad.helpers import tqdm
tactions = set()

View File

@@ -1,6 +1,6 @@
# stuff needed to unpack a kernel
from tinygrad import Variable
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.opt.kernel import Opt, OptOps
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.shape.shapetracker import ShapeTracker
@@ -9,7 +9,7 @@ inf, nan = float('inf'), float('nan')
UOps = Ops
# kernel unpacker
from tinygrad.codegen.kernel import Kernel
from tinygrad.opt.kernel import Kernel
def ast_str_to_ast(ast_str:str) -> UOp: return eval(ast_str)
def ast_str_to_lin(ast_str:str, opts=None): return Kernel(ast_str_to_ast(ast_str), opts=opts)
def kern_str_to_lin(kern_str:str, opts=None):
@@ -101,7 +101,7 @@ def lin_to_feats(lin:Kernel, use_sts=True):
return ret
from tinygrad.device import Device, Buffer
from tinygrad.engine.search import _ensure_buffer_alloc, _time_program
from tinygrad.opt.search import _ensure_buffer_alloc, _time_program
from tinygrad.helpers import to_function_name, CACHELEVEL, diskcache_get, diskcache_put
def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501

View File

@@ -1,4 +1,4 @@
from tinygrad.codegen.kernel import Kernel
from tinygrad.opt.kernel import Kernel
from tqdm import tqdm, trange
import math
import random
@@ -14,7 +14,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.uop.ops import Variable
inf, nan = float('inf'), float('nan')
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.opt.kernel import Opt, OptOps
from extra.optimization.helpers import lin_to_feats, MAX_DIMS

View File

@@ -3,7 +3,7 @@ import numpy as np
import math, random
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import actions, bufs_from_lin, get_kernel_actions
from tinygrad.opt.search import actions, bufs_from_lin, get_kernel_actions
from tinygrad.nn.optim import Adam
from extra.optimization.extract_policynet import PolicyNet
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, time_linearizer

View File

@@ -1,6 +1,6 @@
from typing import List, Tuple
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import get_kernel_actions, actions
from tinygrad.opt.kernel import Kernel
from tinygrad.opt.search import get_kernel_actions, actions
_net = None
def beam_q_estimate(beam:List[Tuple[Kernel, float]]) -> List[Tuple[Kernel, float]]:

View File

@@ -4,8 +4,8 @@ from extra.optimization.helpers import ast_str_to_lin, time_linearizer
from tinygrad import dtypes
from tinygrad.helpers import BEAM, getenv
from tinygrad.device import Device, Compiled
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import beam_search, bufs_from_lin
from tinygrad.opt.kernel import Kernel
from tinygrad.opt.search import beam_search, bufs_from_lin
if __name__ == '__main__':

View File

@@ -6,8 +6,8 @@ from copy import deepcopy
from tinygrad.helpers import getenv, colored
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import bufs_from_lin, actions, get_kernel_actions
from tinygrad.codegen.heuristic import hand_coded_optimizations
from tinygrad.opt.search import bufs_from_lin, actions, get_kernel_actions
from tinygrad.opt.heuristic import hand_coded_optimizations
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, time_linearizer
from extra.optimization.extract_policynet import PolicyNet
from extra.optimization.pretrain_valuenet import ValueNet

View File

@@ -1,5 +1,5 @@
from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer
from tinygrad.engine.search import bufs_from_lin, get_kernel_actions
from tinygrad.opt.search import bufs_from_lin, get_kernel_actions
if __name__ == "__main__":
ast_strs = load_worlds()