mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
use device from LinearizerOptions in kernel search (#3090)
* use device from LinearizerOptions in kernel search; removed all Device.DEFAULT in search.py
* pass device string for parallel pickle
* device for interpreted backends in LinearizerOptions
This commit is contained in:
@@ -68,7 +68,7 @@ class LinearizerOptions(NamedTuple):
|
||||
|
||||
class Kernel:
|
||||
def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None):
|
||||
self.opts = opts or (device.linearizer_opts if isinstance(device:=Device[Device.DEFAULT], Compiled) else LinearizerOptions())
|
||||
self.opts = opts or (device.linearizer_opts if isinstance(device:=Device[Device.DEFAULT], Compiled) else LinearizerOptions(Device.DEFAULT))
|
||||
self.ast = ast
|
||||
assert ast.op == BufferOps.STORE, f"kernels must have a store as the output, got {ast.op}"
|
||||
|
||||
|
||||
@@ -319,7 +319,6 @@ class Compiled:
|
||||
kb = Linearizer(ast, self.linearizer_opts)
|
||||
kb.required_optimizations()
|
||||
from tinygrad.features.search import beam_search, time_linearizer, bufs_from_lin
|
||||
# TODO: this shouldn't use Device.DEFAULT, it should get the device from the LinearizerOptions
|
||||
test_rawbuffers = bufs_from_lin(kb) # allocate scratch buffers for optimization
|
||||
lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
|
||||
timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
|
||||
import itertools, random, math, time, multiprocessing, traceback, signal
|
||||
import itertools, functools, random, math, time, multiprocessing, traceback, signal
|
||||
from tinygrad.device import Device, Compiled, Buffer, CompiledASTRunner
|
||||
from tinygrad.ops import MemBuffer, LazyOp
|
||||
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
|
||||
@@ -47,8 +47,8 @@ def _compile_linearizer(rdev:Compiled, lin:Linearizer, name:Optional[str]=None)
|
||||
src = rdev.renderer(name if name is not None else to_function_name(lin.name), lin.uops) # NOTE: these all have the same name for deduping
|
||||
return rdev.compiler(src), lin.global_size, lin.local_size
|
||||
|
||||
def _try_compile_linearized_w_idx(x):
|
||||
try: return (x[0], _compile_linearizer(cast(Compiled, Device[Device.DEFAULT]), x[1], "test"))
|
||||
def _try_compile_linearized_w_idx(x, device:str):
|
||||
try: return (x[0], _compile_linearizer(cast(Compiled, Device[device]), x[1], "test"))
|
||||
except Exception:
|
||||
if DEBUG >= 4: traceback.print_exc()
|
||||
return (x[0], None)
|
||||
@@ -65,7 +65,7 @@ def bufs_from_lin(lin:Linearizer) -> List[Buffer]:
|
||||
rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
|
||||
for k,lx in bufsts.items():
|
||||
buf_size = prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.real_size() for y in lx)
|
||||
rawbufs[k] = Buffer(Device.DEFAULT, buf_size, lx[0].dtype)
|
||||
rawbufs[k] = Buffer(lin.opts.device, buf_size, lx[0].dtype)
|
||||
assert all(r is not None for r in rawbufs)
|
||||
return cast(List[Buffer], rawbufs)
|
||||
|
||||
@@ -89,7 +89,7 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
|
||||
return acted_lins
|
||||
|
||||
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
|
||||
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
|
||||
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device}
|
||||
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
|
||||
ret = lin.copy()
|
||||
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
|
||||
@@ -98,18 +98,19 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
|
||||
beam: List[Tuple[Linearizer, float]] = []
|
||||
seen_libs = set()
|
||||
|
||||
default_parallel = 1 if Device.DEFAULT in {"CUDA", "HIP"} else 0
|
||||
default_parallel = 1 if lin.opts.device in {"CUDA", "HIP"} else 0
|
||||
pool = multiprocessing.Pool(multiprocessing.cpu_count(), _init_worker) if getenv("PARALLEL", default_parallel) else None
|
||||
|
||||
try:
|
||||
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
|
||||
exiting, st = False, time.perf_counter()
|
||||
dev = Device[Device.DEFAULT]
|
||||
dev = Device[lin.opts.device]
|
||||
assert isinstance(dev, Compiled)
|
||||
while not exiting:
|
||||
acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
|
||||
timed_lins: List[Tuple[Linearizer, float]] = []
|
||||
for i,proc in (pool.imap_unordered(_try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(_try_compile_linearized_w_idx, enumerate(acted_lins))): # noqa: E501
|
||||
_compile_fn = functools.partial(_try_compile_linearized_w_idx, device=lin.opts.device)
|
||||
for i,proc in (pool.imap_unordered(_compile_fn, enumerate(acted_lins)) if pool is not None else map(_compile_fn, enumerate(acted_lins))):
|
||||
if proc is None: continue
|
||||
lib, global_size, local_size = proc
|
||||
if lib in seen_libs: continue
|
||||
@@ -146,10 +147,10 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
|
||||
return ret[1]
|
||||
|
||||
def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
|
||||
key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT} # noqa: E501
|
||||
key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device} # noqa: E501
|
||||
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
|
||||
|
||||
dev = Device[Device.DEFAULT]
|
||||
dev = Device[lin.opts.device]
|
||||
assert isinstance(dev, Compiled)
|
||||
|
||||
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
|
||||
|
||||
@@ -23,4 +23,4 @@ class ClangProgram:
|
||||
def __call__(self, *bufs, vals=(), wait=False): return cpu_time_execution(lambda: self.fxn(*bufs, *vals), enable=wait)
|
||||
|
||||
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict"))
|
||||
ClangDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False), renderer, compile_clang, ClangProgram)
|
||||
ClangDevice = Compiled(MallocAllocator, LinearizerOptions("CLANG", supports_float4=False, has_local=False), renderer, compile_clang, ClangProgram)
|
||||
|
||||
@@ -85,7 +85,7 @@ class CUDADevice(Compiled):
|
||||
|
||||
from tinygrad.runtime.graph.cuda import CUDAGraph
|
||||
super().__init__(CUDAAllocator(self) if not CUDACPU else MallocAllocator,
|
||||
LinearizerOptions(supports_float4_alu=False, global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]),
|
||||
LinearizerOptions("CUDA", supports_float4_alu=False, global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]),
|
||||
CUDARenderer, compile_cuda, functools.partial(CUDAProgram, self), graph=CUDAGraph if not CUDACPU else None)
|
||||
def synchronize(self):
|
||||
if not CUDACPU:
|
||||
|
||||
@@ -93,7 +93,7 @@ class CLDevice(Compiled):
|
||||
if CLDevice.compiler_context is None: CLDevice.compiler_context = self
|
||||
self.queue = checked(cl.clCreateCommandQueue(self.context, self.device_id, cl.CL_QUEUE_PROFILING_ENABLE, ctypes.byref(status)), status)
|
||||
self.pending_copyin: List[memoryview] = []
|
||||
super().__init__(CLAllocator(self), LinearizerOptions(), OpenCLRenderer, compile_cl, functools.partial(CLProgram, self))
|
||||
super().__init__(CLAllocator(self), LinearizerOptions("GPU"), OpenCLRenderer, compile_cl, functools.partial(CLProgram, self))
|
||||
def synchronize(self):
|
||||
check(cl.clFinish(self.queue))
|
||||
self.pending_copyin.clear()
|
||||
|
||||
@@ -88,7 +88,7 @@ class HIPDevice(Compiled):
|
||||
if self.device == 0 and not MOCKHIP: HIPDevice.default_arch_name = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode() # noqa: E501
|
||||
|
||||
from tinygrad.runtime.graph.hip import HIPGraph
|
||||
super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions(device="HIP"), HIPRenderer,
|
||||
super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions("HIP"), HIPRenderer,
|
||||
compile_hip, functools.partial(HIPProgram, self.device), HIPGraph)
|
||||
def synchronize(self):
|
||||
check(hip.hipSetDevice(self.device))
|
||||
|
||||
@@ -62,5 +62,5 @@ class LLVMProgram:
|
||||
self.cfunc = CFUNCTYPE(ctypes.c_int, *([ctypes.c_void_p]*len(bufs)), *([ctypes.c_int32]*len(vals)))(self.fxn)
|
||||
return cpu_time_execution(lambda: self.cfunc(*bufs, *vals), enable=wait)
|
||||
|
||||
LLVMDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False),
|
||||
LLVMDevice = Compiled(MallocAllocator, LinearizerOptions("LLVM", supports_float4=False, has_local=False, has_shared=False),
|
||||
uops_to_llvm_ir, compile_llvm, LLVMProgram)
|
||||
|
||||
@@ -80,7 +80,7 @@ class MetalDevice(Compiled):
|
||||
self.mtl_buffers_in_flight: List[Any] = []
|
||||
self.mv_in_metal: List[memoryview] = []
|
||||
from tinygrad.runtime.graph.metal import MetalGraph
|
||||
super().__init__(MetalAllocator(self), LinearizerOptions(device="METAL"), MetalRenderer,
|
||||
super().__init__(MetalAllocator(self), LinearizerOptions("METAL"), MetalRenderer,
|
||||
compile_metal, functools.partial(MetalProgram, self), functools.partial(MetalGraph, self))
|
||||
def synchronize(self):
|
||||
for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
|
||||
|
||||
Reference in New Issue
Block a user