mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
generic LLVMRenderer class for CPU and AMD (#14321)
* make generic llvmrenderer class for cpu and amd * move `tensor_cores` back to parent * remove empty line * restore extra matcher position * cleanup --------- Co-authored-by: TheVanadium <claude_user@ret2022.localdomain>
This commit is contained in:
@@ -133,18 +133,13 @@ base_rewrite = PatternMatcher([
|
||||
])
|
||||
|
||||
class LLVMRenderer(Renderer):
|
||||
device = "CPU"
|
||||
abi = 'win64cc' if sys.platform == 'win32' else None
|
||||
supports_float4 = True
|
||||
has_local = False
|
||||
global_max: tuple[int, ...] | None = None
|
||||
string_rewrite = base_rewrite + PatternMatcher([(UPat(Ops.WMMA, name="wmma"), render_wmma_amx)])
|
||||
abi: str | None
|
||||
string_rewrite: PatternMatcher
|
||||
code_for_op = {Ops.FDIV: lambda: None, Ops.CMPLT: lambda: None}
|
||||
if AMX: tensor_cores = tc.amx
|
||||
|
||||
extra_matcher = create_non_native_float_pats((dtypes.bfloat16,)) + pm_manual_bf16_cast
|
||||
def render(self, uops: list[UOp]) -> str: return "\n".join((k:=self._render_kernel(uops))[0] + (k[1], self._render_footer(uops)))
|
||||
def _render_footer(self, uops: list[UOp]) -> str: return 'attributes #0 = { alwaysinline nounwind "no-builtins" "no-trapping-math"="true" }'
|
||||
def _render_fn(self, name:str, args:list[tuple[str,DType]], kernel:list[str], prefix:list[str]|None=None) -> str:
|
||||
# NOTE: CPUAllocator promises 0x20 alignment
|
||||
sargs = ", ".join([f"{ldt(dt)}{' noalias align 32' if isinstance(dt, PtrDType) else ''} {name}" for name,dt in args])
|
||||
@@ -181,11 +176,11 @@ class LLVMRenderer(Renderer):
|
||||
assert isinstance(u.dtype, PtrDType)
|
||||
if u.op is Ops.DEFINE_REG:
|
||||
kernel.append(f" {r[u]} = alloca [{u.dtype.size} x {ldt(u.dtype.base)}]")
|
||||
elif self.device == "CPU" and u.op is Ops.DEFINE_LOCAL:
|
||||
kernel.append(f" {r[u]} = alloca [{u.dtype.size} x {ldt(u.dtype.base)}], align 16")
|
||||
else:
|
||||
elif self.has_local:
|
||||
local_args.append(f"@{r[u][1:]} = internal unnamed_addr addrspace(3) global [{u.dtype.size} x {ldt(u.dtype)}] undef, align 16")
|
||||
kernel.append(f" {r[u]} = addrspacecast [{u.dtype.size} x {ldt(u.dtype)}] addrspace(3)* @{r[u][1:]} to [{u.dtype.size} x {ldt(u.dtype)}]*")
|
||||
else:
|
||||
kernel.append(f" {r[u]} = alloca [{u.dtype.size} x {ldt(u.dtype.base)}], align 16")
|
||||
elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
|
||||
elif u.op is Ops.CAST and (ldt(u.dtype) == ldt(u.src[0].dtype) or isinstance(u.dtype, PtrDType)):
|
||||
r[u] = r[u.src[0]] # cast from signed to unsigned of the same size is a noop, or pointer cast
|
||||
@@ -201,6 +196,15 @@ class LLVMRenderer(Renderer):
|
||||
kernel.append(cast(str, l))
|
||||
return tuple(local_args), self._render_fn(name, args, kernel, prefix)
|
||||
|
||||
class CPULLVMRenderer(LLVMRenderer):
|
||||
device = "CPU"
|
||||
has_local = False
|
||||
global_max: tuple[int, ...] | None = None
|
||||
abi = 'win64cc' if sys.platform == 'win32' else None
|
||||
string_rewrite = base_rewrite + PatternMatcher([(UPat(Ops.WMMA, name="wmma"), render_wmma_amx)])
|
||||
def render(self, uops: list[UOp]) -> str: return "\n".join((k:=self._render_kernel(uops))[0] + (k[1], self._render_footer(uops)))
|
||||
def _render_footer(self, uops: list[UOp]) -> str: return 'attributes #0 = { alwaysinline nounwind "no-builtins" "no-trapping-math"="true" }'
|
||||
|
||||
barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n'
|
||||
code_for_workitem = {"g": lambda x: f"tail call i32 @llvm.amdgcn.workgroup.id.{chr(120+int(x))}()",
|
||||
"l": lambda x: f"tail call i32 @llvm.amdgcn.workitem.id.{chr(120+int(x))}()"}
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.device import BufferSpec, DMACPURef, CompilerSet, CompilerPair
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface
|
||||
from tinygrad.runtime.support.hcq import CLikeArgsState
|
||||
from tinygrad.renderer.cstyle import ClangJITRenderer
|
||||
from tinygrad.renderer.llvmir import LLVMRenderer
|
||||
from tinygrad.renderer.llvmir import CPULLVMRenderer
|
||||
from tinygrad.renderer.nir import LVPRenderer
|
||||
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler
|
||||
from tinygrad.runtime.support.elf import jit_loader
|
||||
@@ -135,6 +135,6 @@ class CPUDevice(HCQCompiled):
|
||||
def __init__(self, device:str=""):
|
||||
self.tasks:queue.Queue = queue.Queue()
|
||||
CPUWorker(self, self.tasks, thread_id=0).start()
|
||||
compilers = CompilerSet([CompilerPair(ClangJITRenderer, None), CompilerPair(LLVMRenderer, CPULLVMCompiler, ctrl_var=CPU_LLVM),
|
||||
compilers = CompilerSet([CompilerPair(ClangJITRenderer, None), CompilerPair(CPULLVMRenderer, CPULLVMCompiler, ctrl_var=CPU_LLVM),
|
||||
CompilerPair(LVPRenderer, None, ctrl_var=CPU_LVP)], ctrl_var=CPU_CC)
|
||||
super().__init__(device, CPUAllocator(self), compilers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)
|
||||
|
||||
Reference in New Issue
Block a user