mirror of https://github.com/tinygrad/tinygrad.git
tiny search cleanup (#3910)
* tiny search cleanup

removed some `assert isinstance(dev, Compiled)` and lines

* remove import
tinygrad/codegen/kernel.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 import math, itertools
 from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
 from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps
-from tinygrad.device import Device, Compiled
+from tinygrad.device import Device
 from tinygrad.dtype import dtypes, ImageDType, DType
 from tinygrad.helpers import colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -87,8 +87,8 @@ class LinearizerOptions(NamedTuple):
 
 class Kernel:
   def __init__(self, *ast:LazyOp, opts:Optional[LinearizerOptions]=None):
-    self.opts = opts or (device.compiler.linearizer_opts if isinstance(device:=Device[Device.DEFAULT], Compiled) and device.compiler is not None else
-                         LinearizerOptions(Device.DEFAULT))
+    self.opts = opts if opts is not None else (device.compiler.linearizer_opts if (device:=Device[Device.DEFAULT]).compiler is not None else
+                                               LinearizerOptions(Device.DEFAULT))
     assert all(op.op is BufferOps.STORE for op in ast), f"kernels must have stores as the output, got {ast}"
     assert len(set(op.arg.st.size for op in ast)) == 1, f"all outbufs should have the same size, got {[op.arg.st for op in ast]}"
     self.ast = ast
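Note on the rewritten self.opts line: `opts or default` and `opts if opts is not None else default` differ whenever opts is falsy but still a valid object. A minimal sketch of the distinction, using a hypothetical Opts class (tinygrad's LinearizerOptions is a NamedTuple with fields and is always truthy, so this illustrates the idiom, not a live bug):

    # Hypothetical Opts class; __bool__ makes an instance falsy yet valid.
    class Opts:
      def __init__(self, has_local=True): self.has_local = has_local
      def __bool__(self): return self.has_local

    default, user = Opts(), Opts(has_local=False)
    assert (user or default) is default                     # `or` silently drops user
    assert (user if user is not None else default) is user  # explicit None check keeps it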
tinygrad/features/search.py
@@ -6,11 +6,10 @@ from tinygrad.ops import MemBuffer
 from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
 from tinygrad.dtype import ImageDType
 from tinygrad.codegen.linearizer import Linearizer
-from tinygrad.codegen.kernel import KernelOptError
+from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
 from tinygrad.tensor import Tensor
 from tinygrad.shape.symbolic import sym_infer, Variable
 
-from tinygrad.codegen.kernel import Opt, OptOps
 actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
 actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4] for axis in range(4)]
 actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(5)]
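The actions table is a flat cross-product over (op, axis, amt); combinations that don't fit a given kernel are rejected later when applying them raises KernelOptError. A self-contained sketch of the same enumeration, with local stand-ins for tinygrad's Opt and OptOps:

    from enum import Enum, auto
    from typing import NamedTuple

    class OptOps(Enum): UPCAST = auto(); UNROLL = auto(); LOCAL = auto()
    class Opt(NamedTuple):
      op: OptOps
      axis: int
      amt: int

    # same cross-product shape as the real table: every (amt, axis) pair per op
    actions = [Opt(OptOps.UPCAST, axis, amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
    actions += [Opt(OptOps.UNROLL, axis, amt) for amt in [0,4] for axis in range(4)]
    print(len(actions))  # 6*6 + 2*4 = 44 candidates so far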
@@ -56,11 +55,11 @@ def _compile_linearizer(compiler:Compiler, lin:Linearizer, name:Optional[str]=No
   et = time.perf_counter() - st
   return prog, lin.global_size, lin.local_size, lin.uops.vars(), len(lin.outbufs), et, len(lin.uops.uops)
 
-def _try_compile_linearized_w_idx(x, compiler:Compiler):
-  try: return (x[0], _compile_linearizer(compiler, x[1], "test"))
+def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler):
+  try: return x[0], _compile_linearizer(compiler, x[1], "test")
   except Exception:
     if DEBUG >= 4: traceback.print_exc()
-    return (x[0], None)
+    return x[0], None
 
 # workers should ignore ctrl c
 def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
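_init_worker plus the index-carrying return shape make sense in the context of a multiprocessing pool: SIGINT is ignored in the children so Ctrl-C lands only in the parent, and the index lets out-of-order results be matched back to their linearizers. A hedged sketch of that wiring (the pool setup here is illustrative, not necessarily tinygrad's exact code):

    import signal
    from multiprocessing import get_context

    def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)

    def _try_square_w_idx(x):  # stand-in for _try_compile_linearized_w_idx
      try: return x[0], x[1] * x[1]
      except Exception: return x[0], None

    if __name__ == "__main__":
      with get_context("spawn").Pool(4, initializer=_init_worker) as pool:
        for idx, res in pool.imap_unordered(_try_square_w_idx, enumerate([2, 3, 4])):
          print(idx, res)  # results may arrive out of order; idx recovers the pairing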
@@ -91,11 +90,10 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
       up, lcl = 1, 1
       for s,c in zip(lin2.full_shape, lin2.colors()):
         if c in {"magenta", "yellow"}: up *= s
-        if c in {"cyan", "green", "white"}: lcl *= s
+        elif c in {"cyan", "green", "white"}: lcl *= s
       if up > max_up or lcl > max_lcl: continue
       acted_lins[i+1] = lin2
-    except KernelOptError:
-      pass
+    except KernelOptError: pass
   return acted_lins
 
 beam_pool = None
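The if → elif swap is behavior-preserving: the two color sets are disjoint, so an axis color can match at most one branch, and elif only skips a redundant second membership test.

    assert not ({"magenta", "yellow"} & {"cyan", "green", "white"})  # disjoint, so elif is safe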
@@ -118,7 +116,6 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
   var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
   exiting, st = False, time.perf_counter()
   dev = Device[lin.opts.device]
-  assert isinstance(dev, Compiled)
   while not exiting:
     acted_lins: List[Linearizer] = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
     timed_lins: List[Tuple[Linearizer, float]] = []
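The var_vals dict pins each symbolic shape variable to its midpoint, so the beam benchmarks one representative concrete size rather than an extreme. A sketch with a minimal stand-in for tinygrad's Variable (the real class lives in tinygrad.shape.symbolic):

    class Var:  # hypothetical stand-in with just the fields the comprehension needs
      def __init__(self, name, vmin, vmax): self.name, self.min, self.max = name, vmin, vmax

    ast_vars = [Var("i", 1, 10), Var("j", 0, 7)]
    var_vals = {k: (k.max + k.min) // 2 for k in ast_vars}
    assert [var_vals[v] for v in ast_vars] == [5, 3]  # mid-range values for benchmarking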
@@ -167,7 +164,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
   if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
 
   dev = Device[lin.opts.device]
-  assert isinstance(dev, Compiled) and dev.compiler is not None
+  assert dev.compiler is not None
 
   var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
   lib, global_size, local_size, vars, outcount, _, _ = _compile_linearizer(dev.compiler, lin)
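The guard at the top of this hunk is time_linearizer's disk-cache pattern: store a list of measured times under a key, return the best on a hit. A minimal sketch of the pattern, with an in-memory dict standing in for tinygrad's diskcache_get/diskcache_put:

    _cache: dict = {}
    def diskcache_get(table, key): return _cache.get((table, key))   # stand-in
    def diskcache_put(table, key, val): _cache[(table, key)] = val   # stand-in

    def timed(key, measure, runs=3):
      if (val := diskcache_get("time_linearizer", key)) is not None: return min(val)
      times = [measure() for _ in range(runs)]
      diskcache_put("time_linearizer", key, times)
      return min(times)

    print(timed("kernel0", lambda: 0.002))  # 0.002; a second call hits the cache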