diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 876162448b..2849905a4d 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -68,7 +68,7 @@ class LinearizerOptions(NamedTuple):
 
 class Kernel:
   def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None):
-    self.opts = opts or (device.linearizer_opts if isinstance(device:=Device[Device.DEFAULT], Compiled) else LinearizerOptions())
+    self.opts = opts or (device.linearizer_opts if isinstance(device:=Device[Device.DEFAULT], Compiled) else LinearizerOptions(Device.DEFAULT))
     self.ast = ast
 
     assert ast.op == BufferOps.STORE, f"kernels must have a store as the output, got {ast.op}"
diff --git a/tinygrad/device.py b/tinygrad/device.py
index 1d523ceca7..5e207d4261 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -319,7 +319,6 @@ class Compiled:
       kb = Linearizer(ast, self.linearizer_opts)
       kb.required_optimizations()
       from tinygrad.features.search import beam_search, time_linearizer, bufs_from_lin
-      # TODO: this shouldn't use Device.DEFAULT, it should get the device from the LinearizerOptions
       test_rawbuffers = bufs_from_lin(kb)  # allocate scratch buffers for optimization
       lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
       timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py
index 4644777acf..706d6af3da 100644
--- a/tinygrad/features/search.py
+++ b/tinygrad/features/search.py
@@ -1,5 +1,5 @@
 from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
-import itertools, random, math, time, multiprocessing, traceback, signal
+import itertools, functools, random, math, time, multiprocessing, traceback, signal
 from tinygrad.device import Device, Compiled, Buffer, CompiledASTRunner
 from tinygrad.ops import MemBuffer, LazyOp
 from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
@@ -47,8 +47,8 @@ def _compile_linearizer(rdev:Compiled, lin:Linearizer, name:Optional[str]=None)
   src = rdev.renderer(name if name is not None else to_function_name(lin.name), lin.uops)  # NOTE: these all have the same name for deduping
   return rdev.compiler(src), lin.global_size, lin.local_size
 
-def _try_compile_linearized_w_idx(x):
-  try: return (x[0], _compile_linearizer(cast(Compiled, Device[Device.DEFAULT]), x[1], "test"))
+def _try_compile_linearized_w_idx(x, device:str):
+  try: return (x[0], _compile_linearizer(cast(Compiled, Device[device]), x[1], "test"))
   except Exception:
     if DEBUG >= 4: traceback.print_exc()
     return (x[0], None)
@@ -65,7 +65,7 @@ def bufs_from_lin(lin:Linearizer) -> List[Buffer]:
   rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
   for k,lx in bufsts.items():
     buf_size = prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.real_size() for y in lx)
-    rawbufs[k] = Buffer(Device.DEFAULT, buf_size, lx[0].dtype)
+    rawbufs[k] = Buffer(lin.opts.device, buf_size, lx[0].dtype)
   assert all(r is not None for r in rawbufs)
   return cast(List[Buffer], rawbufs)
 
@@ -89,7 +89,7 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
   return acted_lins
 
 def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
-  key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
+  key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device}
   if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
     ret = lin.copy()
     for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
@@ -98,18 +98,19 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
   beam: List[Tuple[Linearizer, float]] = []
   seen_libs = set()
 
-  default_parallel = 1 if Device.DEFAULT in {"CUDA", "HIP"} else 0
+  default_parallel = 1 if lin.opts.device in {"CUDA", "HIP"} else 0
   pool = multiprocessing.Pool(multiprocessing.cpu_count(), _init_worker) if getenv("PARALLEL", default_parallel) else None
 
   try:
     var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
    exiting, st = False, time.perf_counter()
-    dev = Device[Device.DEFAULT]
+    dev = Device[lin.opts.device]
     assert isinstance(dev, Compiled)
     while not exiting:
       acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
       timed_lins: List[Tuple[Linearizer, float]] = []
-      for i,proc in (pool.imap_unordered(_try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(_try_compile_linearized_w_idx, enumerate(acted_lins))):  # noqa: E501
+      _compile_fn = functools.partial(_try_compile_linearized_w_idx, device=lin.opts.device)
+      for i,proc in (pool.imap_unordered(_compile_fn, enumerate(acted_lins)) if pool is not None else map(_compile_fn, enumerate(acted_lins))):
         if proc is None: continue
         lib, global_size, local_size = proc
         if lib in seen_libs: continue
@@ -146,10 +147,10 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
   return ret[1]
 
 def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:  # noqa: E501
-  key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT}  # noqa: E501
+  key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device}  # noqa: E501
   if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
 
-  dev = Device[Device.DEFAULT]
+  dev = Device[lin.opts.device]
   assert isinstance(dev, Compiled)
 
   var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py
index 1f3752dc4c..64c7e8981d 100644
--- a/tinygrad/runtime/ops_clang.py
+++ b/tinygrad/runtime/ops_clang.py
@@ -23,4 +23,4 @@ class ClangProgram:
   def __call__(self, *bufs, vals=(), wait=False): return cpu_time_execution(lambda: self.fxn(*bufs, *vals), enable=wait)
 
 renderer = functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict"))
-ClangDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False), renderer, compile_clang, ClangProgram)
+ClangDevice = Compiled(MallocAllocator, LinearizerOptions("CLANG", supports_float4=False, has_local=False), renderer, compile_clang, ClangProgram)
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 5ad548d88a..cd23c1fc11 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -85,7 +85,7 @@ class CUDADevice(Compiled):
 
     from tinygrad.runtime.graph.cuda import CUDAGraph
     super().__init__(CUDAAllocator(self) if not CUDACPU else MallocAllocator,
-                     LinearizerOptions(supports_float4_alu=False, global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]),
+                     LinearizerOptions("CUDA", supports_float4_alu=False, global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]),
                      CUDARenderer, compile_cuda, functools.partial(CUDAProgram, self), graph=CUDAGraph if not CUDACPU else None)
   def synchronize(self):
     if not CUDACPU:
diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py
index fdef087612..db19e5ba35 100644
--- a/tinygrad/runtime/ops_gpu.py
+++ b/tinygrad/runtime/ops_gpu.py
@@ -93,7 +93,7 @@ class CLDevice(Compiled):
     if CLDevice.compiler_context is None: CLDevice.compiler_context = self
     self.queue = checked(cl.clCreateCommandQueue(self.context, self.device_id, cl.CL_QUEUE_PROFILING_ENABLE, ctypes.byref(status)), status)
     self.pending_copyin: List[memoryview] = []
-    super().__init__(CLAllocator(self), LinearizerOptions(), OpenCLRenderer, compile_cl, functools.partial(CLProgram, self))
+    super().__init__(CLAllocator(self), LinearizerOptions("GPU"), OpenCLRenderer, compile_cl, functools.partial(CLProgram, self))
   def synchronize(self):
     check(cl.clFinish(self.queue))
     self.pending_copyin.clear()
diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py
index bd53e89e21..5f09b374a7 100644
--- a/tinygrad/runtime/ops_hip.py
+++ b/tinygrad/runtime/ops_hip.py
@@ -88,7 +88,7 @@ class HIPDevice(Compiled):
     if self.device == 0 and not MOCKHIP: HIPDevice.default_arch_name = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode()  # noqa: E501
 
     from tinygrad.runtime.graph.hip import HIPGraph
-    super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions(device="HIP"), HIPRenderer,
+    super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions("HIP"), HIPRenderer,
                      compile_hip, functools.partial(HIPProgram, self.device), HIPGraph)
   def synchronize(self):
     check(hip.hipSetDevice(self.device))
diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py
index 3264f01e87..35c1f6bc58 100644
--- a/tinygrad/runtime/ops_llvm.py
+++ b/tinygrad/runtime/ops_llvm.py
@@ -62,5 +62,5 @@ class LLVMProgram:
     self.cfunc = CFUNCTYPE(ctypes.c_int, *([ctypes.c_void_p]*len(bufs)), *([ctypes.c_int32]*len(vals)))(self.fxn)
     return cpu_time_execution(lambda: self.cfunc(*bufs, *vals), enable=wait)
 
-LLVMDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False),
+LLVMDevice = Compiled(MallocAllocator, LinearizerOptions("LLVM", supports_float4=False, has_local=False, has_shared=False),
                       uops_to_llvm_ir, compile_llvm, LLVMProgram)
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index d9709642a8..abf66f9a7f 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -80,7 +80,7 @@ class MetalDevice(Compiled):
     self.mtl_buffers_in_flight: List[Any] = []
     self.mv_in_metal: List[memoryview] = []
     from tinygrad.runtime.graph.metal import MetalGraph
-    super().__init__(MetalAllocator(self), LinearizerOptions(device="METAL"), MetalRenderer,
+    super().__init__(MetalAllocator(self), LinearizerOptions("METAL"), MetalRenderer,
                      compile_metal, functools.partial(MetalProgram, self), functools.partial(MetalGraph, self))
   def synchronize(self):
     for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
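
Not part of the patch: a minimal sketch of what the change enables, i.e. driving the search helpers against an explicit device via that device's own LinearizerOptions instead of Device.DEFAULT. It assumes `ast` already exists as a LazyOp rooted at a BufferOps.STORE (per the assert in kernel.py above), that `Linearizer` is importable from tinygrad.codegen.linearizer in this tree, and "CLANG" is just an example backend.

from tinygrad.device import Device, Compiled
from tinygrad.codegen.linearizer import Linearizer   # assumed import path for this tree
from tinygrad.features.search import bufs_from_lin, beam_search, time_linearizer

dev = Device["CLANG"]                       # any Compiled backend; "CLANG" is only an example
assert isinstance(dev, Compiled)

lin = Linearizer(ast, dev.linearizer_opts)  # opts now carry the device name as their first field
rawbufs = bufs_from_lin(lin)                # scratch buffers are allocated on lin.opts.device
best = beam_search(lin, rawbufs, amt=4)     # candidates compile and time on that same device
print(time_linearizer(best, rawbufs, allow_test_size=False, clear_l2=True))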