use device from LinearizerOptions in kernel search (#3090)

* use device from LinearizerOptions in kernel search

Removed all uses of Device.DEFAULT in search.py.

* pass the device name as a string so the compile helper can be pickled for the parallel worker pool

* set the device for interpreted (non-Compiled) backends in LinearizerOptions
This commit is contained in:
chenyu
2024-01-11 14:46:03 -05:00
committed by GitHub
parent 93e3f952aa
commit 0fe6904351
9 changed files with 18 additions and 18 deletions

View File

@@ -68,7 +68,7 @@ class LinearizerOptions(NamedTuple):
class Kernel:
def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None):
self.opts = opts or (device.linearizer_opts if isinstance(device:=Device[Device.DEFAULT], Compiled) else LinearizerOptions())
self.opts = opts or (device.linearizer_opts if isinstance(device:=Device[Device.DEFAULT], Compiled) else LinearizerOptions(Device.DEFAULT))
self.ast = ast
assert ast.op == BufferOps.STORE, f"kernels must have a store as the output, got {ast.op}"

View File

@@ -319,7 +319,6 @@ class Compiled:
kb = Linearizer(ast, self.linearizer_opts)
kb.required_optimizations()
from tinygrad.features.search import beam_search, time_linearizer, bufs_from_lin
# TODO: this shouldn't use Device.DEFAULT, it should get the device from the LinearizerOptions
test_rawbuffers = bufs_from_lin(kb) # allocate scratch buffers for optimization
lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])

View File

@@ -1,5 +1,5 @@
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
import itertools, random, math, time, multiprocessing, traceback, signal
import itertools, functools, random, math, time, multiprocessing, traceback, signal
from tinygrad.device import Device, Compiled, Buffer, CompiledASTRunner
from tinygrad.ops import MemBuffer, LazyOp
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
@@ -47,8 +47,8 @@ def _compile_linearizer(rdev:Compiled, lin:Linearizer, name:Optional[str]=None)
src = rdev.renderer(name if name is not None else to_function_name(lin.name), lin.uops) # NOTE: these all have the same name for deduping
return rdev.compiler(src), lin.global_size, lin.local_size
def _try_compile_linearized_w_idx(x):
try: return (x[0], _compile_linearizer(cast(Compiled, Device[Device.DEFAULT]), x[1], "test"))
def _try_compile_linearized_w_idx(x, device:str):
try: return (x[0], _compile_linearizer(cast(Compiled, Device[device]), x[1], "test"))
except Exception:
if DEBUG >= 4: traceback.print_exc()
return (x[0], None)
@@ -65,7 +65,7 @@ def bufs_from_lin(lin:Linearizer) -> List[Buffer]:
rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
for k,lx in bufsts.items():
buf_size = prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.real_size() for y in lx)
rawbufs[k] = Buffer(Device.DEFAULT, buf_size, lx[0].dtype)
rawbufs[k] = Buffer(lin.opts.device, buf_size, lx[0].dtype)
assert all(r is not None for r in rawbufs)
return cast(List[Buffer], rawbufs)
@@ -89,7 +89,7 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
return acted_lins
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device}
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
ret = lin.copy()
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
@@ -98,18 +98,19 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
beam: List[Tuple[Linearizer, float]] = []
seen_libs = set()
default_parallel = 1 if Device.DEFAULT in {"CUDA", "HIP"} else 0
default_parallel = 1 if lin.opts.device in {"CUDA", "HIP"} else 0
pool = multiprocessing.Pool(multiprocessing.cpu_count(), _init_worker) if getenv("PARALLEL", default_parallel) else None
try:
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
exiting, st = False, time.perf_counter()
dev = Device[Device.DEFAULT]
dev = Device[lin.opts.device]
assert isinstance(dev, Compiled)
while not exiting:
acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
timed_lins: List[Tuple[Linearizer, float]] = []
for i,proc in (pool.imap_unordered(_try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(_try_compile_linearized_w_idx, enumerate(acted_lins))): # noqa: E501
_compile_fn = functools.partial(_try_compile_linearized_w_idx, device=lin.opts.device)
for i,proc in (pool.imap_unordered(_compile_fn, enumerate(acted_lins)) if pool is not None else map(_compile_fn, enumerate(acted_lins))):
if proc is None: continue
lib, global_size, local_size = proc
if lib in seen_libs: continue
@@ -146,10 +147,10 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
return ret[1]
def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT} # noqa: E501
key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device} # noqa: E501
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
dev = Device[Device.DEFAULT]
dev = Device[lin.opts.device]
assert isinstance(dev, Compiled)
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}

View File

@@ -23,4 +23,4 @@ class ClangProgram:
def __call__(self, *bufs, vals=(), wait=False): return cpu_time_execution(lambda: self.fxn(*bufs, *vals), enable=wait)
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict"))
ClangDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False), renderer, compile_clang, ClangProgram)
ClangDevice = Compiled(MallocAllocator, LinearizerOptions("CLANG", supports_float4=False, has_local=False), renderer, compile_clang, ClangProgram)

View File

@@ -85,7 +85,7 @@ class CUDADevice(Compiled):
from tinygrad.runtime.graph.cuda import CUDAGraph
super().__init__(CUDAAllocator(self) if not CUDACPU else MallocAllocator,
LinearizerOptions(supports_float4_alu=False, global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]),
LinearizerOptions("CUDA", supports_float4_alu=False, global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]),
CUDARenderer, compile_cuda, functools.partial(CUDAProgram, self), graph=CUDAGraph if not CUDACPU else None)
def synchronize(self):
if not CUDACPU:

View File

@@ -93,7 +93,7 @@ class CLDevice(Compiled):
if CLDevice.compiler_context is None: CLDevice.compiler_context = self
self.queue = checked(cl.clCreateCommandQueue(self.context, self.device_id, cl.CL_QUEUE_PROFILING_ENABLE, ctypes.byref(status)), status)
self.pending_copyin: List[memoryview] = []
super().__init__(CLAllocator(self), LinearizerOptions(), OpenCLRenderer, compile_cl, functools.partial(CLProgram, self))
super().__init__(CLAllocator(self), LinearizerOptions("GPU"), OpenCLRenderer, compile_cl, functools.partial(CLProgram, self))
def synchronize(self):
check(cl.clFinish(self.queue))
self.pending_copyin.clear()

View File

@@ -88,7 +88,7 @@ class HIPDevice(Compiled):
if self.device == 0 and not MOCKHIP: HIPDevice.default_arch_name = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode() # noqa: E501
from tinygrad.runtime.graph.hip import HIPGraph
super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions(device="HIP"), HIPRenderer,
super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions("HIP"), HIPRenderer,
compile_hip, functools.partial(HIPProgram, self.device), HIPGraph)
def synchronize(self):
check(hip.hipSetDevice(self.device))

View File

@@ -62,5 +62,5 @@ class LLVMProgram:
self.cfunc = CFUNCTYPE(ctypes.c_int, *([ctypes.c_void_p]*len(bufs)), *([ctypes.c_int32]*len(vals)))(self.fxn)
return cpu_time_execution(lambda: self.cfunc(*bufs, *vals), enable=wait)
LLVMDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False),
LLVMDevice = Compiled(MallocAllocator, LinearizerOptions("LLVM", supports_float4=False, has_local=False, has_shared=False),
uops_to_llvm_ir, compile_llvm, LLVMProgram)

View File

@@ -80,7 +80,7 @@ class MetalDevice(Compiled):
self.mtl_buffers_in_flight: List[Any] = []
self.mv_in_metal: List[memoryview] = []
from tinygrad.runtime.graph.metal import MetalGraph
super().__init__(MetalAllocator(self), LinearizerOptions(device="METAL"), MetalRenderer,
super().__init__(MetalAllocator(self), LinearizerOptions("METAL"), MetalRenderer,
compile_metal, functools.partial(MetalProgram, self), functools.partial(MetalGraph, self))
def synchronize(self):
for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted()