From 3d64a984326ca224cb93904af280a3530a2268e0 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 9 Jun 2025 22:17:10 -0700 Subject: [PATCH] remove cpu graph, it's different from the others (#10743) * remove cpu graph, it's different from the others * remote was blacklisting CPUGraph --- tinygrad/runtime/graph/cpu.py | 67 ---------------------------------- tinygrad/runtime/ops_cpu.py | 10 ++--- tinygrad/runtime/ops_llvm.py | 6 +-- tinygrad/runtime/ops_remote.py | 4 +- 4 files changed, 6 insertions(+), 81 deletions(-) delete mode 100644 tinygrad/runtime/graph/cpu.py diff --git a/tinygrad/runtime/graph/cpu.py b/tinygrad/runtime/graph/cpu.py deleted file mode 100644 index 390c35a204..0000000000 --- a/tinygrad/runtime/graph/cpu.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import cast, TypeVar, Generic, get_args as get_typing_args -import itertools -from tinygrad.helpers import dedup, flatten, DEBUG, to_function_name -from tinygrad.engine.jit import GraphRunner, GraphException -from tinygrad.device import Buffer -from tinygrad.engine.realize import ExecItem, CompiledRunner -from tinygrad.uop.ops import Variable -from tinygrad.dtype import DType, dtypes -from tinygrad.renderer.cstyle import ClangRenderer -from tinygrad.renderer.llvmir import LLVMRenderer, ldt - -T = TypeVar('T') -class BatchedGraph(Generic[T], GraphRunner): - def __init__(self, device, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]): - renderer_class = get_typing_args(getattr(self, "__orig_bases__")[0])[0] - if not issubclass(type(device.renderer), renderer_class) and not isinstance(device.renderer, renderer_class): raise GraphException - - super().__init__(jit_cache, input_rawbuffers, var_vals) - self.base_bufs = dedup(b.base for ji in jit_cache for b in ji.bufs if b is not None and b not in input_rawbuffers) - self.base_rawbufs = [b._buf for b in self.base_bufs] - - targs = [(f"arg{i}", x.dtype.ptr()) for i,x in enumerate(input_rawbuffers)] + \ - [(f"cbuf{i}", dtypes.char.ptr()) for i in range(len(self.base_bufs))] + \ - sorted([(f"{v.expr}", dtypes.int) for v in var_vals]) - code = self._prepare_code(device.renderer, jit_cache, input_rawbuffers, targs) - if DEBUG >= 4: print(code) - self.clprg = device.runtime("batched", device.compiler.compile_cached(code)) - - def _prepare_code(self, renderer:T, jit_cache:list[ExecItem], input_rawbuffers:list[Buffer], targs:list[tuple[str, DType]]) -> str: return "" - def __call__(self, rawbufs: list[Buffer], var_vals: dict[Variable, int], wait=False): - return self.clprg(*[x._buf for x in rawbufs], *self.base_rawbufs, *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0].expr)], wait=wait) - -class CPUGraph(BatchedGraph[ClangRenderer]): - def _prepare_code(self, renderer:ClangRenderer, jit_cache:list[ExecItem], input_rawbuffers:list[Buffer], targs:list[tuple[str, DType]]) -> str: - def render_arg(buf): - if buf in input_rawbuffers: return f"arg{input_rawbuffers.index(buf)}" - return f"({renderer.render_dtype(buf.dtype)}*)(cbuf{self.base_bufs.index(buf.base)} + {buf.offset})" - - batched = ["void batched("+','.join([f"{renderer.render_dtype(x[1])} {x[0]}" for x in targs])+") {"] - for i, ji in enumerate(jit_cache): - args = [render_arg(buf) for buf in ji.bufs] + [x.expr for x in cast(CompiledRunner, ji.prg).p.vars] - batched.append(f" {to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(args)});") - batched.append("}") - - prep = [renderer._render(cast(CompiledRunner, ji.prg).p.uops or []) for i,ji in enumerate(jit_cache)] - funcs = dedup(renderer._render_body(prep[i][0], *prep[i][1:], cast(CompiledRunner, ji.prg).p.uops, - ["static", "__attribute__((always_inline))"]) for i,ji in enumerate(jit_cache)) - defines = dedup(itertools.chain.from_iterable(renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache)) - entry = renderer._render_entry("batched", [(t[0], (t[1], False)) for t in targs]) - return '\n'.join(defines) + '\n' + '\n'.join([''.join(f) for f in funcs]) + '\n'.join(batched) + '\n' + entry - -class LLVMGraph(BatchedGraph[LLVMRenderer]): - def _prepare_code(self, renderer, jit_cache:list[ExecItem], input_rawbuffers:list[Buffer], targs:list[tuple[str, DType]]) -> str: - out = [] - for i,ji in enumerate(jit_cache): - args = [] - for j,buf in enumerate(cast(list[Buffer], ji.bufs)): - arg = f"%arg{input_rawbuffers.index(buf)}" if buf in input_rawbuffers else f"%b{i}_{j}" - if buf not in input_rawbuffers: out.append(f" {arg} = getelementptr inbounds i8,ptr %cbuf{self.base_bufs.index(buf.base)},i64 {buf.offset}") - args.append(f"{ldt(buf.dtype.ptr())} {arg}") - args += [f"{ldt(x.dtype)} %{x.expr}" for x in cast(CompiledRunner, ji.prg).p.vars] - out.append(f" call void @{to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(args)})") - - kernels = dedup(tuple(renderer._render_kernel(cast(CompiledRunner, ji.prg).p.uops, ["internal"]) for i,ji in enumerate(jit_cache))) - kernels += [((), renderer._render_fn("batched", [(f"%{x[0]}", x[1]) for x in targs], out))] - assert flatten(x[0] for x in kernels) == [] # global definitions are not used in CPU mode right now - return "\n".join([x[1] for x in kernels] + [renderer._render_footer(cast(CompiledRunner, ji.prg).p.uops)]) diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py index e180a0a6b3..c5a15afb52 100644 --- a/tinygrad/runtime/ops_cpu.py +++ b/tinygrad/runtime/ops_cpu.py @@ -1,4 +1,4 @@ -import functools, platform, subprocess, sys +import platform, subprocess, sys from tinygrad.helpers import capstone_flatdump, getenv from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram from tinygrad.runtime.support.elf import jit_loader @@ -18,9 +18,5 @@ class ClangJITCompiler(Compiler): def disassemble(self, lib:bytes): return capstone_flatdump(lib) -class ClangDevice(Compiled): - def __init__(self, device:str): - from tinygrad.runtime.graph.cpu import CPUGraph - super().__init__(device, MallocAllocator, ClangRenderer(), ClangJITCompiler(), CPUProgram, functools.partial(CPUGraph, self)) - -CPUDevice = ClangDevice \ No newline at end of file +class CPUDevice(Compiled): + def __init__(self, device:str): super().__init__(device, MallocAllocator, ClangRenderer(), ClangJITCompiler(), CPUProgram) diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index 1c1cb6ba5c..c628ad138b 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -1,4 +1,4 @@ -import functools, ctypes, platform +import ctypes, platform from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram from tinygrad.helpers import OSX, getenv, capstone_flatdump, DEBUG from tinygrad.renderer.llvmir import LLVMRenderer @@ -70,6 +70,4 @@ class HostLLVMCompiler(LLVMCompiler): super().__init__(cpu.decode(), feats.decode()) class LLVMDevice(Compiled): - def __init__(self, device:str): - from tinygrad.runtime.graph.cpu import LLVMGraph - super().__init__(device, MallocAllocator, LLVMRenderer(), HostLLVMCompiler(), CPUProgram, functools.partial(LLVMGraph, self)) + def __init__(self, device:str): super().__init__(device, MallocAllocator, LLVMRenderer(), HostLLVMCompiler(), CPUProgram) diff --git a/tinygrad/runtime/ops_remote.py b/tinygrad/runtime/ops_remote.py index 00440fd6a8..991da324c6 100644 --- a/tinygrad/runtime/ops_remote.py +++ b/tinygrad/runtime/ops_remote.py @@ -16,7 +16,6 @@ from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing from tinygrad.engine.jit import GraphRunner, MultiGraphRunner, ExecItem, graph_class from tinygrad.engine.realize import CompiledRunner, BufferXfer from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, BufferSpec -from tinygrad.runtime.graph.cpu import CPUGraph # ***** API ***** @@ -171,8 +170,7 @@ class RemoteHandler: case SessionFree(): del self.sessions[unwrap(c.session)] case GetProperties(): cls, args = dev.renderer.__reduce__() - # CPUGraph re-renders kernel from uops specified in CompiledRunner, this is not supported - graph_cls = gt if (gt:=graph_class(Device[self.base_device])) is not CPUGraph else None + graph_cls = graph_class(Device[self.base_device]) rp = RemoteProperties( real_device=dev.device, renderer=(cls.__module__, cls.__name__, args), graph_supported=graph_cls is not None, graph_supports_multi=graph_cls is not None and issubclass(graph_cls, MultiGraphRunner),