From 23b084e70a12c46dd678dfa00cad7ea2fb2aed4c Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Tue, 23 Jan 2024 20:34:56 -0800
Subject: [PATCH] add device name to device, all are constructed (#3221)

---
 tinygrad/device.py            | 19 ++++++++-----------
 tinygrad/runtime/ops_clang.py |  6 ++++--
 tinygrad/runtime/ops_cpu.py   |  3 ++-
 tinygrad/runtime/ops_cuda.py  |  2 +-
 tinygrad/runtime/ops_disk.py  |  2 +-
 tinygrad/runtime/ops_gpu.py   |  2 +-
 tinygrad/runtime/ops_hip.py   |  2 +-
 tinygrad/runtime/ops_llvm.py  |  2 +-
 tinygrad/runtime/ops_metal.py |  2 +-
 tinygrad/runtime/ops_torch.py |  3 ++-
 10 files changed, 22 insertions(+), 21 deletions(-)

diff --git a/tinygrad/device.py b/tinygrad/device.py
index 66b1137817..c2656836c6 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -25,9 +25,7 @@ class _Device:
   @functools.lru_cache(maxsize=None)  # this class is a singleton, pylint: disable=method-cache-max-size-none
   def __get_canonicalized_item(self, ix:str) -> Union[Interpreted, Compiled]:
     x = ix.split(":")[0].upper()
-    ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0]  # noqa: E501
-    if isinstance(ret, type): ret = ret(ix)
-    return ret
+    return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix)  # noqa: E501
   @functools.cached_property
   def DEFAULT(self) -> str:
     device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None)  # type: ignore
@@ -189,8 +187,8 @@ class InterpretedASTRunner(JITRunner):
     return et
 
 class Interpreted:
-  def __init__(self, allocator: Allocator, fxn_for_op:Dict[Op, Callable]):
-    self.allocator, self.fxn_for_op = allocator, fxn_for_op
+  def __init__(self, device:str, allocator: Allocator, fxn_for_op:Dict[Op, Callable]):
+    self.dname, self.allocator, self.fxn_for_op = device, allocator, fxn_for_op
     self.synchronize, self.codegen, self.graph = lambda: None, None, None
 
   @functools.lru_cache(None)  # pylint: disable=method-cache-max-size-none
@@ -305,21 +303,20 @@ class CompiledASTRunner(JITRunner):
     if local_size: lra['local_size'] = local_size
     et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2)
     if do_update_stats: update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit,
-                                     lra=lra, device=rawbufs[0].device, first_run=self.first_run)
+                                     lra=lra, device=self.device.dname, first_run=self.first_run)
     self.first_run = False
     return et
 
 class Compiled:
-  def __init__(self, allocator:Allocator, linearizer_opts:LinearizerOptions, renderer, compiler, compiler_cachekey, runtime, graph=None):
-    self.allocator, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.graph, self.compiler_cachekey = \
-      allocator, linearizer_opts, renderer, compiler, runtime, graph, None if getenv("DISABLE_COMPILER_CACHE") else compiler_cachekey
+  def __init__(self, device:str, allocator:Allocator, linearizer_opts:LinearizerOptions, renderer, compiler, compiler_cachekey, runtime, graph=None):
+    self.dname, self.allocator, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.graph, self.compiler_cachekey = \
+      device, allocator, linearizer_opts, renderer, compiler, runtime, graph, None if getenv("DISABLE_COMPILER_CACHE") else compiler_cachekey
   def synchronize(self): pass  # override this in your device
 
   def to_program(self, k:Linearizer) -> CompiledASTRunner:
     assert self.compiler is not None, f"compiler is None, can't build {k.ast}"
     k.linearize()
-    src = self.renderer(to_function_name(k.name), k.uops)
-    return CompiledASTRunner(k.ast, k.name, src, self, k.global_size, k.local_size)
+    return CompiledASTRunner(k.ast, k.name, self.renderer(to_function_name(k.name), k.uops), self, k.global_size, k.local_size)
 
   def get_linearizer(self, ast:LazyOp) -> Linearizer:
     if DEBUG >= 3:
diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py
index 702e0ca612..832dbf21af 100644
--- a/tinygrad/runtime/ops_clang.py
+++ b/tinygrad/runtime/ops_clang.py
@@ -22,5 +22,7 @@ class ClangProgram:
 
   def __call__(self, *bufs, vals=(), wait=False): return cpu_time_execution(lambda: self.fxn(*bufs, *vals), enable=wait)
 
-ClangDevice = Compiled(MallocAllocator, LinearizerOptions("CLANG", supports_float4=False, has_local=False),
-                       functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict")), compile_clang, "compile_clang", ClangProgram)
+class ClangDevice(Compiled):
+  def __init__(self, device:str):
+    super().__init__(device, MallocAllocator, LinearizerOptions("CLANG", supports_float4=False, has_local=False),
+                     functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict")), compile_clang, "compile_clang", ClangProgram)
diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py
index 78d07f4d27..75babed640 100644
--- a/tinygrad/runtime/ops_cpu.py
+++ b/tinygrad/runtime/ops_cpu.py
@@ -42,4 +42,5 @@ class NumpyAllocator(Allocator):
   def copyin(self, dest:np.ndarray, src:memoryview): np.copyto(dest, np.frombuffer(src, dest.dtype).reshape(dest.shape))
   def copyout(self, dest:memoryview, src:np.ndarray): np.copyto(np.frombuffer(dest, src.dtype).reshape(src.shape), src)
 
-CPUDevice = Interpreted(NumpyAllocator(), numpy_fxn_for_op)
+class CPUDevice(Interpreted):
+  def __init__(self, device:str): super().__init__(device, NumpyAllocator(), numpy_fxn_for_op)
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 02033e1016..56f22e6834 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -83,7 +83,7 @@ class CUDADevice(Compiled):
     self.arch = f"sm_{major.value}{minor.value}" if not CUDACPU else "sm_35"
 
     from tinygrad.runtime.graph.cuda import CUDAGraph
-    super().__init__(CUDAAllocator(self) if not CUDACPU else MallocAllocator,
+    super().__init__(device, CUDAAllocator(self) if not CUDACPU else MallocAllocator,
                      LinearizerOptions("CUDA", supports_float4_alu=False, global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]),
                      CUDARenderer, functools.partial(compile_cuda,arch=self.arch), f"compile_cuda_{self.arch}",
                      functools.partial(CUDAProgram, self), graph=CUDAGraph if not CUDACPU else None)
diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py
index b2cba18e96..1c8627964f 100644
--- a/tinygrad/runtime/ops_disk.py
+++ b/tinygrad/runtime/ops_disk.py
@@ -54,4 +54,4 @@ class DiskAllocator(Allocator):
     dest[:] = src._buf()
 
 class DiskDevice(Interpreted):
-  def __init__(self, device:str): super().__init__(DiskAllocator(device[len("disk:"):]), disk_fxn_for_op)
\ No newline at end of file
+  def __init__(self, device:str): super().__init__(device, DiskAllocator(device[len("disk:"):]), disk_fxn_for_op)
\ No newline at end of file
diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py
index 4a6c860ed3..c4aa69599e 100644
--- a/tinygrad/runtime/ops_gpu.py
+++ b/tinygrad/runtime/ops_gpu.py
@@ -94,7 +94,7 @@ class CLDevice(Compiled):
     self.pending_copyin: List[memoryview] = []
 
     compile_key = hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()
-    super().__init__(CLAllocator(self), LinearizerOptions("GPU"), OpenCLRenderer,
+    super().__init__(device, CLAllocator(self), LinearizerOptions("GPU"), OpenCLRenderer,
                      functools.partial(compile_cl, self), f"compile_cl_{compile_key}", functools.partial(CLProgram, self))
   def synchronize(self):
     check(cl.clFinish(self.queue))
diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py
index 5174e2af95..b91c776fc9 100644
--- a/tinygrad/runtime/ops_hip.py
+++ b/tinygrad/runtime/ops_hip.py
@@ -104,7 +104,7 @@ class HIPDevice(Compiled):
     self.pending_events: List[hip.hipEvent_t] = []
 
     from tinygrad.runtime.graph.hip import HIPGraph
-    super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions("HIP"), HIPRenderer,
+    super().__init__(device, MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions("HIP"), HIPRenderer,
                      functools.partial(compile_hip,arch=self.arch), f"compile_hip_{self.arch}", functools.partial(HIPProgram, self.device), HIPGraph)
   def synchronize(self):
     check(hip.hipSetDevice(self.device))
diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py
index 2a78becd93..a5ec322dea 100644
--- a/tinygrad/runtime/ops_llvm.py
+++ b/tinygrad/runtime/ops_llvm.py
@@ -38,5 +38,5 @@ class LLVMDevice(Compiled):
     backing_mod = llvm.parse_assembly(str())
     backing_mod.triple = llvm.get_process_triple()
     self.engine: llvm.executionengine.ExecutionEngine = llvm.create_mcjit_compiler(backing_mod, self.target_machine)
-    super().__init__(MallocAllocator, LinearizerOptions("LLVM", supports_float4=False, has_local=False, has_shared=False),
+    super().__init__(device, MallocAllocator, LinearizerOptions("LLVM", supports_float4=False, has_local=False, has_shared=False),
                      uops_to_llvm_ir, functools.partial(compile_llvm, self), "compile_llvm", functools.partial(LLVMProgram, self))
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index 21e385435f..2156f64e63 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -78,7 +78,7 @@ class MetalDevice(Compiled):
     self.mtl_buffers_in_flight: List[Any] = []
     self.mv_in_metal: List[memoryview] = []
     from tinygrad.runtime.graph.metal import MetalGraph
-    super().__init__(MetalAllocator(self), LinearizerOptions("METAL"), MetalRenderer,
+    super().__init__(device, MetalAllocator(self), LinearizerOptions("METAL"), MetalRenderer,
                      compile_metal_xcode if getenv("METAL_XCODE") else functools.partial(compile_metal, self.device), "compile_metal",
                      functools.partial(MetalProgram, self), functools.partial(MetalGraph, self))
   def synchronize(self):
diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py
index a6494e3771..ac704c6153 100644
--- a/tinygrad/runtime/ops_torch.py
+++ b/tinygrad/runtime/ops_torch.py
@@ -42,4 +42,5 @@ class TorchAllocator(Allocator):
   def copyin(self, dest:torch.Tensor, src:memoryview): dest.copy_(torch.frombuffer(src, dtype=dest.dtype))
  def copyout(self, dest:memoryview, src:torch.Tensor): torch.frombuffer(dest, dtype=src.dtype).copy_(src.flatten())
 
-TorchDevice = Interpreted(TorchAllocator(), torch_fxn_for_op)
+class TorchDevice(Interpreted):
+  def __init__(self, device:str): super().__init__(device, TorchAllocator(), torch_fxn_for_op)
\ No newline at end of file
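
After this patch, every backend is a <NAME>Device class: the Interpreted ones (CPUDevice, TorchDevice, DiskDevice) as well as the Compiled ones (ClangDevice, CUDADevice, CLDevice, HIPDevice, LLVMDevice, MetalDevice). The registry in _Device.__get_canonicalized_item can therefore always construct the looked-up class with the device string and drop its isinstance(ret, type) special case, and the constructed device remembers its own name as dname, which update_stats now reads via self.device.dname instead of rawbufs[0].device. Below is a minimal, self-contained sketch of that lookup-and-construct pattern; BaseDevice, _Registry, and _get are illustrative stand-ins for this note, not tinygrad's actual classes.

# Illustrative sketch only: BaseDevice, _Registry, and _get are simplified
# stand-ins for tinygrad's Compiled/Interpreted and _Device, not its real code.
import functools

class BaseDevice:
  def __init__(self, device:str):
    self.dname = device  # the device keeps its own name, as Compiled/Interpreted do after this patch

class CPUDevice(BaseDevice): pass
class DiskDevice(BaseDevice): pass

class _Registry:
  _devices = {"CPU": CPUDevice, "DISK": DiskDevice}
  def __getitem__(self, ix:str) -> BaseDevice: return self._get(ix)
  @functools.lru_cache(maxsize=None)  # singleton registry: at most one device object per device string
  def _get(self, ix:str) -> BaseDevice:
    # look the class up by its canonical name and construct it with the full device string;
    # no isinstance(..., type) branch is needed once every backend is a class
    return self._devices[ix.split(":")[0].upper()](ix)

Device = _Registry()
assert Device["CPU"].dname == "CPU"
assert Device["disk:/tmp/weights"].dname == "disk:/tmp/weights"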