diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index de0eddb7e3..da9b4de780 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -141,11 +141,17 @@ class LLVMRenderer(Renderer): (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None), ]) - def render(self, uops: list[UOp]) -> str: + def render(self, uops: list[UOp]) -> str: return "\n".join((k:=self._render_kernel(uops))[0] + (k[1], self._render_footer())) + def _render_footer(self) -> str: return 'attributes #0 = { alwaysinline nounwind "no-builtins" "no-trapping-math"="true" }' + def _render_fn(self, name:str, args:list[tuple[str,DType]], kernel:list[str], prefix:list[str]|None=None) -> str: + # NOTE: MallocAllocator promises 0x20 alignment + sargs = ", ".join([f"{ldt(dt)}{' noalias align 32' if isinstance(dt, PtrDType) else ''} {name}" for name,dt in args]) + sprefix = "".join([f" {x}" for x in (prefix or []) + [self.abi] if x is not None]) + return "\n".join([f"define{sprefix} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"]) + def _render_kernel(self, uops: list[UOp], prefix:list[str]|None=None) -> tuple[tuple[str, ...], str]: r: dict[UOp, str] = {} - args: list[str] = [] + args: list[tuple[str, DType]] = [] kernel: list[str] = [] - end_lines: dict[str, None] = {} vc = -1 local_args: list[str] = [] @@ -170,8 +176,7 @@ class LLVMRenderer(Renderer): continue if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR): r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}" - # NOTE: MallocAllocator promises 0x20 alignment - args.append(f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {r[u]}") + args.append((r[u], u.dtype)) elif u.op == Ops.DEFINE_LOCAL: r[u] = f"%local_{u.arg}" assert isinstance(u.dtype, PtrDType) @@ -201,17 +206,7 @@ class LLVMRenderer(Renderer): vc += 1 kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg}]") r[x] = f"%acc{vc}" - - # output the function. chr(10) is '\n' (python < 3.12 doesn't support backslashes in f-strings) - prg = f'''\ -define{(' '+self.abi) if self.abi is not None else ''} void @{name}({','.join(args)}) #0 {{ -{chr(10).join(kernel)} - ret void -}} -{chr(10).join(end_lines.keys())} -attributes #0 = {{ nounwind "no-builtins" "no-trapping-math"="true" }} -''' - return prg if len(local_args) == 0 else "\n".join(local_args)+f"\n{prg}" + return tuple(local_args), self._render_fn(name, args, kernel, prefix) barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n' code_for_workitem = {"g": lambda x: f"tail call i32 @llvm.amdgcn.workgroup.id.{chr(120+int(x))}()", diff --git a/tinygrad/runtime/graph/cpu.py b/tinygrad/runtime/graph/cpu.py index 0e71f449d6..4ea41e05eb 100644 --- a/tinygrad/runtime/graph/cpu.py +++ b/tinygrad/runtime/graph/cpu.py @@ -1,45 +1,67 @@ -from typing import cast +from typing import cast, TypeVar, Generic, get_args as get_typing_args import itertools -from tinygrad.helpers import dedup, DEBUG, to_function_name +from tinygrad.helpers import dedup, flatten, DEBUG, to_function_name from tinygrad.engine.jit import GraphRunner, GraphException from tinygrad.device import Buffer from tinygrad.engine.realize import ExecItem, CompiledRunner from tinygrad.ops import Variable -from tinygrad.dtype import dtypes +from tinygrad.dtype import DType, dtypes from tinygrad.renderer.cstyle import ClangRenderer +from tinygrad.renderer.llvmir import LLVMRenderer, ldt -class CPUGraph(GraphRunner): +T = TypeVar('T') +class BatchedGraph(Generic[T], GraphRunner): def __init__(self, device, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]): - if not issubclass(type(device.renderer), ClangRenderer) and not isinstance(device.renderer, ClangRenderer): raise GraphException - super().__init__(jit_cache, input_rawbuffers, var_vals) + renderer_class = get_typing_args(getattr(self, "__orig_bases__")[0])[0] + if not issubclass(type(device.renderer), renderer_class) and not isinstance(device.renderer, renderer_class): raise GraphException + super().__init__(jit_cache, input_rawbuffers, var_vals) self.base_bufs = dedup(b.base for ji in jit_cache for b in ji.bufs if b is not None and b not in input_rawbuffers) self.base_rawbufs = [b._buf for b in self.base_bufs] - targs = [(f"arg{i}", (x.dtype.ptr(), False)) for i,x in enumerate(input_rawbuffers)] + \ - [(f"cbuf{i}", (dtypes.char.ptr(), False)) for i in range(len(self.base_bufs))] + \ - sorted([(f"{v.expr}", (dtypes.int, False)) for v in var_vals]) + targs = [(f"arg{i}", x.dtype.ptr()) for i,x in enumerate(input_rawbuffers)] + \ + [(f"cbuf{i}", dtypes.char.ptr()) for i in range(len(self.base_bufs))] + \ + sorted([(f"{v.expr}", dtypes.int) for v in var_vals]) + code = self._prepare_code(device.renderer, jit_cache, input_rawbuffers, targs) + if DEBUG >= 4: print(code) + self.clprg = device.runtime("batched", device.compiler.compile_cached(code)) + def _prepare_code(self, renderer:T, jit_cache:list[ExecItem], input_rawbuffers:list[Buffer], targs:list[tuple[str, DType]]) -> str: return "" + def __call__(self, rawbufs: list[Buffer], var_vals: dict[Variable, int], wait=False): + return self.clprg(*[x._buf for x in rawbufs], *self.base_rawbufs, *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0].expr)], wait=wait) + +class CPUGraph(BatchedGraph[ClangRenderer]): + def _prepare_code(self, renderer:ClangRenderer, jit_cache:list[ExecItem], input_rawbuffers:list[Buffer], targs:list[tuple[str, DType]]) -> str: def render_arg(buf): if buf in input_rawbuffers: return f"arg{input_rawbuffers.index(buf)}" - return f"({device.renderer.render_dtype(buf.dtype)}*)(cbuf{self.base_bufs.index(buf.base)} + {buf.offset})" + return f"({renderer.render_dtype(buf.dtype)}*)(cbuf{self.base_bufs.index(buf.base)} + {buf.offset})" - batched = ["void batched("+','.join([f"{device.renderer.render_dtype(x[1][0])} {x[0]}" for x in targs])+") {"] + batched = ["void batched("+','.join([f"{renderer.render_dtype(x[1])} {x[0]}" for x in targs])+") {"] for i, ji in enumerate(jit_cache): args = [render_arg(buf) for buf in ji.bufs] + [x.expr for x in cast(CompiledRunner, ji.prg).p.vars] batched.append(f" {to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(args)});") batched.append("}") - prep = [device.renderer._render(cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache)] - funcs = dedup(device.renderer._render_body(prep[i][0], *prep[i][1:], cast(CompiledRunner, ji.prg).p.uops, - ["static", "__attribute__((always_inline))"]) for i,ji in enumerate(jit_cache)) + prep = [renderer._render(cast(CompiledRunner, ji.prg).p.uops or []) for i,ji in enumerate(jit_cache)] + funcs = dedup(renderer._render_body(prep[i][0], *prep[i][1:], cast(CompiledRunner, ji.prg).p.uops, + ["static", "__attribute__((always_inline))"]) for i,ji in enumerate(jit_cache)) + defines = dedup(itertools.chain.from_iterable(renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache)) + entry = renderer._render_entry("batched", [(t[0], (t[1], False)) for t in targs]) + return '\n'.join(defines) + '\n' + '\n'.join([''.join(f) for f in funcs]) + '\n'.join(batched) + '\n' + entry - defines = dedup(itertools.chain.from_iterable(device.renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache)) - entry = device.renderer._render_entry("batched", targs) - code = '\n'.join(defines) + '\n' + '\n'.join([''.join(f) for f in funcs]) + '\n'.join(batched) + '\n' + entry +class LLVMGraph(BatchedGraph[LLVMRenderer]): + def _prepare_code(self, renderer, jit_cache:list[ExecItem], input_rawbuffers:list[Buffer], targs:list[tuple[str, DType]]) -> str: + out = [] + for i,ji in enumerate(jit_cache): + args = [] + for j,buf in enumerate(cast(list[Buffer], ji.bufs)): + arg = f"%arg{input_rawbuffers.index(buf)}" if buf in input_rawbuffers else f"%b{i}_{j}" + if buf not in input_rawbuffers: out.append(f" {arg} = getelementptr inbounds i8,ptr %cbuf{self.base_bufs.index(buf.base)},i64 {buf.offset}") + args.append(f"{ldt(buf.dtype.ptr())} {arg}") + args += [f"{ldt(x.dtype)} %{x.expr}" for x in cast(CompiledRunner, ji.prg).p.vars] + out.append(f" call void @{to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(args)})") - if DEBUG >= 4: print(code) - self.clprg = device.runtime("batched", device.compiler.compile_cached(code)) - - def __call__(self, rawbufs: list[Buffer], var_vals: dict[Variable, int], wait=False): - return self.clprg(*[x._buf for x in rawbufs], *self.base_rawbufs, *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0].expr)], wait=wait) + kernels = dedup(tuple(renderer._render_kernel(cast(CompiledRunner, ji.prg).p.uops, ["internal"]) for i,ji in enumerate(jit_cache))) + kernels += [((), renderer._render_fn("batched", [(f"%{x[0]}", x[1]) for x in targs], out))] + assert flatten(x[0] for x in kernels) == [] # global definitions are not used in CPU mode right now + return "\n".join([x[1] for x in kernels] + [renderer._render_footer()]) diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index 5015ba5927..5dfad557aa 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -1,4 +1,4 @@ -import ctypes, platform +import functools, ctypes, platform from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram from tinygrad.helpers import OSX, getenv, capstone_flatdump, DEBUG from tinygrad.renderer.llvmir import LLVMRenderer @@ -60,4 +60,5 @@ class HostLLVMCompiler(LLVMCompiler): class LLVMDevice(Compiled): def __init__(self, device:str): - super().__init__(device, MallocAllocator, LLVMRenderer(), HostLLVMCompiler(), CPUProgram) + from tinygrad.runtime.graph.cpu import LLVMGraph + super().__init__(device, MallocAllocator, LLVMRenderer(), HostLLVMCompiler(), CPUProgram, functools.partial(LLVMGraph, self))