remove cpu graph, it's different from the others (#10743)

* remove cpu graph, it's different from the others

* remote was blacklisting CPUGraph
George Hotz
2025-06-09 22:17:10 -07:00
committed by GitHub
parent 245b1d3a46
commit 3d64a98432
4 changed files with 6 additions and 81 deletions
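For context on the "different from the others" part: most GraphRunners replay device command queues that were already recorded, while CPUGraph re-rendered every jitted kernel from its uops and recompiled them into one fused entry point that was called once per JIT invocation. Below is a rough sketch of the kind of C source it stitched together; kernel names, argument types, and the offsets are invented for illustration. This is also why the remote backend had to blacklist it: re-rendering needs the kernel uops, which the removed comment in remote.py calls out.

# Illustrative only: roughly the shape of source CPUGraph._prepare_code assembled.
# Kernel names, signatures, and the cbuf0 offsets below are made up for this sketch.
example_batched_c = """
static __attribute__((always_inline)) void E_4(float* data0, float* data1) { /* rendered kernel body */ }
static __attribute__((always_inline)) void E_16(float* data0, float* data1) { /* rendered kernel body */ }
void batched(float* arg0, char* cbuf0) {
  E_4(arg0, (float*)(cbuf0 + 0));
  E_16((float*)(cbuf0 + 64), arg0);
}
"""
print(example_batched_c)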


@@ -1,67 +0,0 @@
-from typing import cast, TypeVar, Generic, get_args as get_typing_args
-import itertools
-from tinygrad.helpers import dedup, flatten, DEBUG, to_function_name
-from tinygrad.engine.jit import GraphRunner, GraphException
-from tinygrad.device import Buffer
-from tinygrad.engine.realize import ExecItem, CompiledRunner
-from tinygrad.uop.ops import Variable
-from tinygrad.dtype import DType, dtypes
-from tinygrad.renderer.cstyle import ClangRenderer
-from tinygrad.renderer.llvmir import LLVMRenderer, ldt
-T = TypeVar('T')
-class BatchedGraph(Generic[T], GraphRunner):
-  def __init__(self, device, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
-    renderer_class = get_typing_args(getattr(self, "__orig_bases__")[0])[0]
-    if not issubclass(type(device.renderer), renderer_class) and not isinstance(device.renderer, renderer_class): raise GraphException
-    super().__init__(jit_cache, input_rawbuffers, var_vals)
-    self.base_bufs = dedup(b.base for ji in jit_cache for b in ji.bufs if b is not None and b not in input_rawbuffers)
-    self.base_rawbufs = [b._buf for b in self.base_bufs]
-    targs = [(f"arg{i}", x.dtype.ptr()) for i,x in enumerate(input_rawbuffers)] + \
-            [(f"cbuf{i}", dtypes.char.ptr()) for i in range(len(self.base_bufs))] + \
-            sorted([(f"{v.expr}", dtypes.int) for v in var_vals])
-    code = self._prepare_code(device.renderer, jit_cache, input_rawbuffers, targs)
-    if DEBUG >= 4: print(code)
-    self.clprg = device.runtime("batched", device.compiler.compile_cached(code))
-  def _prepare_code(self, renderer:T, jit_cache:list[ExecItem], input_rawbuffers:list[Buffer], targs:list[tuple[str, DType]]) -> str: return ""
-  def __call__(self, rawbufs: list[Buffer], var_vals: dict[Variable, int], wait=False):
-    return self.clprg(*[x._buf for x in rawbufs], *self.base_rawbufs, *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0].expr)], wait=wait)
-class CPUGraph(BatchedGraph[ClangRenderer]):
-  def _prepare_code(self, renderer:ClangRenderer, jit_cache:list[ExecItem], input_rawbuffers:list[Buffer], targs:list[tuple[str, DType]]) -> str:
-    def render_arg(buf):
-      if buf in input_rawbuffers: return f"arg{input_rawbuffers.index(buf)}"
-      return f"({renderer.render_dtype(buf.dtype)}*)(cbuf{self.base_bufs.index(buf.base)} + {buf.offset})"
-    batched = ["void batched("+','.join([f"{renderer.render_dtype(x[1])} {x[0]}" for x in targs])+") {"]
-    for i, ji in enumerate(jit_cache):
-      args = [render_arg(buf) for buf in ji.bufs] + [x.expr for x in cast(CompiledRunner, ji.prg).p.vars]
-      batched.append(f"  {to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(args)});")
-    batched.append("}")
-    prep = [renderer._render(cast(CompiledRunner, ji.prg).p.uops or []) for i,ji in enumerate(jit_cache)]
-    funcs = dedup(renderer._render_body(prep[i][0], *prep[i][1:], cast(CompiledRunner, ji.prg).p.uops,
-                  ["static", "__attribute__((always_inline))"]) for i,ji in enumerate(jit_cache))
-    defines = dedup(itertools.chain.from_iterable(renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache))
-    entry = renderer._render_entry("batched", [(t[0], (t[1], False)) for t in targs])
-    return '\n'.join(defines) + '\n' + '\n'.join([''.join(f) for f in funcs]) + '\n'.join(batched) + '\n' + entry
-class LLVMGraph(BatchedGraph[LLVMRenderer]):
-  def _prepare_code(self, renderer, jit_cache:list[ExecItem], input_rawbuffers:list[Buffer], targs:list[tuple[str, DType]]) -> str:
-    out = []
-    for i,ji in enumerate(jit_cache):
-      args = []
-      for j,buf in enumerate(cast(list[Buffer], ji.bufs)):
-        arg = f"%arg{input_rawbuffers.index(buf)}" if buf in input_rawbuffers else f"%b{i}_{j}"
-        if buf not in input_rawbuffers: out.append(f"  {arg} = getelementptr inbounds i8,ptr %cbuf{self.base_bufs.index(buf.base)},i64 {buf.offset}")
-        args.append(f"{ldt(buf.dtype.ptr())} {arg}")
-      args += [f"{ldt(x.dtype)} %{x.expr}" for x in cast(CompiledRunner, ji.prg).p.vars]
-      out.append(f"  call void @{to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(args)})")
-    kernels = dedup(tuple(renderer._render_kernel(cast(CompiledRunner, ji.prg).p.uops, ["internal"]) for i,ji in enumerate(jit_cache)))
-    kernels += [((), renderer._render_fn("batched", [(f"%{x[0]}", x[1]) for x in targs], out))]
-    assert flatten(x[0] for x in kernels) == [] # global definitions are not used in CPU mode right now
-    return "\n".join([x[1] for x in kernels] + [renderer._render_footer(cast(CompiledRunner, ji.prg).p.uops)])
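One non-obvious piece of the deleted file is how BatchedGraph recovers its renderer class from the Generic parameter. Here is a standalone sketch of that trick; the renderer class is an empty stand-in, not tinygrad's real one.

# Minimal sketch of the __orig_bases__ / get_args trick from BatchedGraph.__init__.
from typing import Generic, TypeVar, get_args

T = TypeVar('T')
class ClangRenderer: pass  # stand-in for the real renderer class

class BatchedGraph(Generic[T]):
  def __init__(self):
    # on the subclass, __orig_bases__ holds BatchedGraph[ClangRenderer]; get_args extracts ClangRenderer
    self.renderer_class = get_args(getattr(self, "__orig_bases__")[0])[0]

class CPUGraph(BatchedGraph[ClangRenderer]): pass

assert CPUGraph().renderer_class is ClangRenderer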


@@ -1,4 +1,4 @@
-import functools, platform, subprocess, sys
+import platform, subprocess, sys
 from tinygrad.helpers import capstone_flatdump, getenv
 from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram
 from tinygrad.runtime.support.elf import jit_loader
@@ -18,9 +18,5 @@ class ClangJITCompiler(Compiler):
   def disassemble(self, lib:bytes): return capstone_flatdump(lib)
-class ClangDevice(Compiled):
-  def __init__(self, device:str):
-    from tinygrad.runtime.graph.cpu import CPUGraph
-    super().__init__(device, MallocAllocator, ClangRenderer(), ClangJITCompiler(), CPUProgram, functools.partial(CPUGraph, self))
-CPUDevice = ClangDevice
+class CPUDevice(Compiled):
+  def __init__(self, device:str): super().__init__(device, MallocAllocator, ClangRenderer(), ClangJITCompiler(), CPUProgram)


@@ -1,4 +1,4 @@
-import functools, ctypes, platform
+import ctypes, platform
 from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram
 from tinygrad.helpers import OSX, getenv, capstone_flatdump, DEBUG
 from tinygrad.renderer.llvmir import LLVMRenderer
@@ -70,6 +70,4 @@ class HostLLVMCompiler(LLVMCompiler):
     super().__init__(cpu.decode(), feats.decode())
 class LLVMDevice(Compiled):
-  def __init__(self, device:str):
-    from tinygrad.runtime.graph.cpu import LLVMGraph
-    super().__init__(device, MallocAllocator, LLVMRenderer(), HostLLVMCompiler(), CPUProgram, functools.partial(LLVMGraph, self))
+  def __init__(self, device:str): super().__init__(device, MallocAllocator, LLVMRenderer(), HostLLVMCompiler(), CPUProgram)


@@ -16,7 +16,6 @@ from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing
 from tinygrad.engine.jit import GraphRunner, MultiGraphRunner, ExecItem, graph_class
 from tinygrad.engine.realize import CompiledRunner, BufferXfer
 from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, BufferSpec
-from tinygrad.runtime.graph.cpu import CPUGraph
 # ***** API *****
@@ -171,8 +170,7 @@ class RemoteHandler:
       case SessionFree(): del self.sessions[unwrap(c.session)]
       case GetProperties():
        cls, args = dev.renderer.__reduce__()
-        # CPUGraph re-renders kernel from uops specified in CompiledRunner, this is not supported
-        graph_cls = gt if (gt:=graph_class(Device[self.base_device])) is not CPUGraph else None
+        graph_cls = graph_class(Device[self.base_device])
         rp = RemoteProperties(
           real_device=dev.device, renderer=(cls.__module__, cls.__name__, args),
           graph_supported=graph_cls is not None, graph_supports_multi=graph_cls is not None and issubclass(graph_cls, MultiGraphRunner),
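With the special case gone, graph_class either returns a usable GraphRunner subclass or None, and the two capability flags in RemoteProperties follow directly from that. A small stub sketch of that derivation (stand-in classes, not the real tinygrad types):

# Stub illustration of how graph_supported / graph_supports_multi are derived above.
class GraphRunner: pass
class MultiGraphRunner(GraphRunner): pass

def graph_flags(graph_cls):
  graph_supported = graph_cls is not None
  graph_supports_multi = graph_supported and issubclass(graph_cls, MultiGraphRunner)
  return graph_supported, graph_supports_multi

assert graph_flags(None) == (False, False)
assert graph_flags(GraphRunner) == (True, False)
assert graph_flags(MultiGraphRunner) == (True, True)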