mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
remove cpu graph, it's different from the others (#10743)
* remove cpu graph, it's different from the others
* remote was blacklisting CPUGraph
This commit is contained in:
@@ -1,67 +0,0 @@
|
||||
from typing import cast, TypeVar, Generic, get_args as get_typing_args
|
||||
import itertools
|
||||
from tinygrad.helpers import dedup, flatten, DEBUG, to_function_name
|
||||
from tinygrad.engine.jit import GraphRunner, GraphException
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
from tinygrad.uop.ops import Variable
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from tinygrad.renderer.cstyle import ClangRenderer
|
||||
from tinygrad.renderer.llvmir import LLVMRenderer, ldt
|
||||
|
||||
T = TypeVar('T')
|
||||
class BatchedGraph(Generic[T], GraphRunner):
  """Graph runner that renders a whole JIT cache into a single program named "batched".

  Concrete graphs subclass ``BatchedGraph[SomeRenderer]`` and override ``_prepare_code``; the
  renderer class given as the type argument is checked against the device's renderer.
  """
  def __init__(self, device, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
    """Build and compile the batched program for ``jit_cache`` on ``device``.

    Raises GraphException when the device's renderer does not match the renderer class this
    graph was parameterized with.
    """
    # the expected renderer class is the type argument of the first generic base (BatchedGraph[...])
    generic_base = getattr(self, "__orig_bases__")[0]
    renderer_class = get_typing_args(generic_base)[0]
    if not (issubclass(type(device.renderer), renderer_class) or isinstance(device.renderer, renderer_class)):
      raise GraphException

    super().__init__(jit_cache, input_rawbuffers, var_vals)

    # base buffers of every non-input buffer referenced by the cache, each listed once
    internal_bases = []
    for ji in jit_cache:
      for b in ji.bufs:
        if b is None or b in input_rawbuffers: continue
        internal_bases.append(b.base)
    self.base_bufs = dedup(internal_bases)
    self.base_rawbufs = [base._buf for base in self.base_bufs]

    # parameters of batched(): inputs first, then internal base buffers, then the sorted variables
    input_args = [(f"arg{idx}", buf.dtype.ptr()) for idx, buf in enumerate(input_rawbuffers)]
    base_args = [(f"cbuf{idx}", dtypes.char.ptr()) for idx in range(len(self.base_bufs))]
    var_args = sorted([(f"{v.expr}", dtypes.int) for v in var_vals])
    targs = input_args + base_args + var_args

    code = self._prepare_code(device.renderer, jit_cache, input_rawbuffers, targs)
    if DEBUG >= 4: print(code)
    compiled = device.compiler.compile_cached(code)
    self.clprg = device.runtime("batched", compiled)

  def _prepare_code(self, renderer:T, jit_cache:list[ExecItem], input_rawbuffers:list[Buffer], targs:list[tuple[str, DType]]) -> str:
    """Return the source of the batched program; this base implementation returns an empty string."""
    return ""

  def __call__(self, rawbufs: list[Buffer], var_vals: dict[Variable, int], wait=False):
    """Run the compiled program with input buffers, internal base buffers and variable values, in that order."""
    input_ptrs = [buf._buf for buf in rawbufs]
    # variable values are passed in the order of their sorted exprs
    ordered_vals = [val for _, val in sorted(var_vals.items(), key=lambda kv: kv[0].expr)]
    return self.clprg(*input_ptrs, *self.base_rawbufs, *ordered_vals, wait=wait)
|
||||
|
||||
class CPUGraph(BatchedGraph[ClangRenderer]):
  """Batched graph for the Clang renderer: emits one C function ``batched`` that calls every kernel in order."""
  def _prepare_code(self, renderer:ClangRenderer, jit_cache:list[ExecItem], input_rawbuffers:list[Buffer], targs:list[tuple[str, DType]]) -> str:
    """Render the defines, the kernel bodies, the ``batched`` wrapper and its entry point as one source string."""
    runners = [cast(CompiledRunner, ji.prg) for ji in jit_cache]

    def render_arg(buf):
      # inputs are referenced by position; any other buffer is a typed pointer into its base buffer
      if buf not in input_rawbuffers:
        return f"({renderer.render_dtype(buf.dtype)}*)(cbuf{self.base_bufs.index(buf.base)} + {buf.offset})"
      return f"arg{input_rawbuffers.index(buf)}"

    # the wrapper: one call per cached kernel, in cache order
    signature = ','.join(f"{renderer.render_dtype(dt)} {name}" for name, dt in targs)
    batched = [f"void batched({signature}) {{"]
    for ji, runner in zip(jit_cache, runners):
      call_args = [render_arg(buf) for buf in ji.bufs]
      call_args.extend(v.expr for v in runner.p.vars)
      batched.append(f" {to_function_name(runner.p.name)}({','.join(call_args)});")
    batched.append("}")

    # kernel bodies are re-rendered from each runner's uops as static always-inline functions, duplicates dropped
    prep = [renderer._render(runner.p.uops or []) for runner in runners]
    funcs = dedup([renderer._render_body(rendered[0], *rendered[1:], runner.p.uops, ["static", "__attribute__((always_inline))"])
                   for rendered, runner in zip(prep, runners)])
    defines = dedup([d for runner in runners for d in renderer._render_defines(runner.p.uops)])
    entry = renderer._render_entry("batched", [(name, (dt, False)) for name, dt in targs])

    # NOTE(review): no separator is inserted between the kernel bodies and the wrapper - presumably each
    # rendered body already ends with a newline; confirm against the renderer
    bodies = '\n'.join(''.join(f) for f in funcs)
    return '\n'.join(defines) + '\n' + bodies + '\n'.join(batched) + '\n' + entry
|
||||
|
||||
class LLVMGraph(BatchedGraph[LLVMRenderer]):
  """Batched graph for the LLVM renderer: emits an IR function ``batched`` that calls every kernel in order."""
  def _prepare_code(self, renderer, jit_cache:list[ExecItem], input_rawbuffers:list[Buffer], targs:list[tuple[str, DType]]) -> str:
    """Render the kernels, the ``batched`` function and the footer as one LLVM IR string."""
    body = []
    for i,ji in enumerate(jit_cache):
      runner = cast(CompiledRunner, ji.prg)
      call_args = []
      for j,buf in enumerate(cast(list[Buffer], ji.bufs)):
        if buf in input_rawbuffers:
          operand = f"%arg{input_rawbuffers.index(buf)}"
        else:
          # non-input buffer: its address is a byte offset into the base buffer it belongs to
          operand = f"%b{i}_{j}"
          body.append(f" {operand} = getelementptr inbounds i8,ptr %cbuf{self.base_bufs.index(buf.base)},i64 {buf.offset}")
        call_args.append(f"{ldt(buf.dtype.ptr())} {operand}")
      call_args.extend(f"{ldt(v.dtype)} %{v.expr}" for v in runner.p.vars)
      body.append(f" call void @{to_function_name(runner.p.name)}({','.join(call_args)})")

    # kernels are re-rendered from each runner's uops with "internal" passed to the renderer, duplicates dropped
    kernels = dedup([renderer._render_kernel(cast(CompiledRunner, ji.prg).p.uops, ["internal"]) for ji in jit_cache])
    batched_params = [(f"%{name}", dt) for name, dt in targs]
    kernels.append(((), renderer._render_fn("batched", batched_params, body)))
    assert flatten(k[0] for k in kernels) == [] # global definitions are not used in CPU mode right now

    # NOTE: `runner` is still bound to the last kernel of the loop above, so the footer is rendered from its uops
    pieces = [k[1] for k in kernels]
    pieces.append(renderer._render_footer(runner.p.uops))
    return "\n".join(pieces)
|
||||
@@ -1,4 +1,4 @@
|
||||
import functools, platform, subprocess, sys
|
||||
import platform, subprocess, sys
|
||||
from tinygrad.helpers import capstone_flatdump, getenv
|
||||
from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram
|
||||
from tinygrad.runtime.support.elf import jit_loader
|
||||
@@ -18,9 +18,5 @@ class ClangJITCompiler(Compiler):
|
||||
|
||||
def disassemble(self, lib:bytes): return capstone_flatdump(lib)
|
||||
|
||||
class ClangDevice(Compiled):
  """Device built on the Clang renderer and JIT compiler, with CPUGraph bound to this device as its graph."""
  def __init__(self, device:str):
    # the graph class is imported at construction time rather than at module import
    from tinygrad.runtime.graph.cpu import CPUGraph
    renderer, compiler = ClangRenderer(), ClangJITCompiler()
    graph = functools.partial(CPUGraph, self)
    super().__init__(device, MallocAllocator, renderer, compiler, CPUProgram, graph)
|
||||
|
||||
CPUDevice = ClangDevice
|
||||
class CPUDevice(Compiled):
  """Device built on the Clang renderer and JIT compiler; no graph argument is passed to the base class."""
  def __init__(self, device:str):
    renderer, compiler = ClangRenderer(), ClangJITCompiler()
    super().__init__(device, MallocAllocator, renderer, compiler, CPUProgram)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import functools, ctypes, platform
|
||||
import ctypes, platform
|
||||
from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram
|
||||
from tinygrad.helpers import OSX, getenv, capstone_flatdump, DEBUG
|
||||
from tinygrad.renderer.llvmir import LLVMRenderer
|
||||
@@ -70,6 +70,4 @@ class HostLLVMCompiler(LLVMCompiler):
|
||||
super().__init__(cpu.decode(), feats.decode())
|
||||
|
||||
class LLVMDevice(Compiled):
|
||||
def __init__(self, device:str):
|
||||
from tinygrad.runtime.graph.cpu import LLVMGraph
|
||||
super().__init__(device, MallocAllocator, LLVMRenderer(), HostLLVMCompiler(), CPUProgram, functools.partial(LLVMGraph, self))
|
||||
def __init__(self, device:str): super().__init__(device, MallocAllocator, LLVMRenderer(), HostLLVMCompiler(), CPUProgram)
|
||||
|
||||
@@ -16,7 +16,6 @@ from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing
|
||||
from tinygrad.engine.jit import GraphRunner, MultiGraphRunner, ExecItem, graph_class
|
||||
from tinygrad.engine.realize import CompiledRunner, BufferXfer
|
||||
from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, BufferSpec
|
||||
from tinygrad.runtime.graph.cpu import CPUGraph
|
||||
|
||||
# ***** API *****
|
||||
|
||||
@@ -171,8 +170,7 @@ class RemoteHandler:
|
||||
case SessionFree(): del self.sessions[unwrap(c.session)]
|
||||
case GetProperties():
|
||||
cls, args = dev.renderer.__reduce__()
|
||||
# CPUGraph re-renders kernel from uops specified in CompiledRunner, this is not supported
|
||||
graph_cls = gt if (gt:=graph_class(Device[self.base_device])) is not CPUGraph else None
|
||||
graph_cls = graph_class(Device[self.base_device])
|
||||
rp = RemoteProperties(
|
||||
real_device=dev.device, renderer=(cls.__module__, cls.__name__, args),
|
||||
graph_supported=graph_cls is not None, graph_supports_multi=graph_cls is not None and issubclass(graph_cls, MultiGraphRunner),
|
||||
|
||||
Reference in New Issue
Block a user