NIR: new-style (fix beam) (#13887)

* NIR: fix beam

* new reduce

* Revert "Revert "NIR: new-style compilers (#13875)" (#13888)"

This reverts commit fc4faed0b2.

* oops
This commit is contained in:
Christopher Milan
2025-12-29 15:41:29 -08:00
committed by GitHub
parent fc4faed0b2
commit 0497387e45
7 changed files with 34 additions and 59 deletions

View File

@@ -1,6 +1,6 @@
from typing import Callable, cast, Any
from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes
from tinygrad.helpers import DEBUG, OSX, unwrap, charptr
from tinygrad.helpers import DEBUG, OSX, unwrap, charptr, fromimport
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str
@@ -115,7 +115,8 @@ def nidx(b:mesa.nir_builder, buf, off, dtype, gate=None) -> mesa.nir_def:
return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
class NIRRenderer(Renderer):
suffix = "NAK"
suffix = "NIR"
nir_options: bytes
global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
code_for_op = {**{k:lambda:None for k in u_aop.keys()}, **{k:lambda:None for k in s_aop.keys()}, **{k:lambda:None for k in f_aop.keys()}}
@@ -158,13 +159,17 @@ class NIRRenderer(Renderer):
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
])
def __init__(self): mesa.glsl_type_singleton_init_or_ref()
def __reduce__(self): return self.__class__, self.args
def __init__(self, *args):
self.compiler = fromimport("tinygrad.runtime.support.compiler_mesa", self.__class__.__name__.replace("Renderer", "Compiler"))(*args)
self.args = args
if hasattr(self.compiler, "nir_options"): self.nir_options = self.compiler.nir_options
mesa.glsl_type_singleton_init_or_ref()
def __del__(self):
with contextlib.suppress(AttributeError): mesa.glsl_type_singleton_decref()
@property
def nir_options(self): raise NotImplementedError("needs nir_options")
def param(self, b:mesa.nir_builder, x, sz:int) -> mesa.nir_def: raise NotImplementedError("needs param")
def prerender(self, uops:list[UOp]):
self.b = mesa.nir_builder_init_simple_shader(mesa.MESA_SHADER_COMPUTE, mesa.nir_shader_compiler_options.from_buffer_copy(self.nir_options), None)
@@ -216,20 +221,9 @@ class NIRRenderer(Renderer):
return ret
class NIRRendererWithOpts(NIRRenderer):
def __init__(self, dev=None, nir_options=None):
self.dev, self._nir_options = dev, nir_options
super().__init__()
def __reduce__(self): return self.__class__, (None, self.nir_options)
@property
def nir_options(self):
if self._nir_options is None: self._nir_options = self.dev.compiler.nir_options
return self._nir_options
class NAKRenderer(NIRRendererWithOpts):
class NAKRenderer(NIRRenderer):
device = "NV"
param = nir_instr(nc=1, num_components=1, bs=lambda sz:sz*8, also=lambda self,sz: setattr(self, "param_idx", self.param_idx + sz),
intrins={"ALIGN_MUL":lambda sz:sz}, srcs=lambda self,b: [nsrc(nimm(b, 0, dtypes.int)), nsrc(nimm(b, self.param_idx, dtypes.int))])(
lambda self, b, x, sz: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_ldc_nv))
@@ -261,7 +255,7 @@ _nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2
nc=4, bs=32, num_components=4, srcs=lambda b,img,coord:[nsrc(x) for x in [img, tovec(b, coord), nundef(b, dtypes.int), nimm(b, 0, dtypes.int)]])(
lambda b,img,coord,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load")))
class IR3Renderer(NIRRendererWithOpts):
class IR3Renderer(NIRRenderer):
device = "QCOM"
def nload_img(ctx,img,coord):

View File

@@ -9,7 +9,6 @@ from tinygrad.renderer.cstyle import ClangJITRenderer
from tinygrad.renderer.llvmir import LLVMRenderer
from tinygrad.renderer.nir import LVPRenderer
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler
from tinygrad.runtime.support.compiler_mesa import LVPCompiler
from tinygrad.runtime.support.elf import jit_loader
from tinygrad.uop.ops import sint
@@ -72,7 +71,7 @@ class CPUProgram(HCQProgram):
except OSError: pass
def __init__(self, dev, name:str, lib:bytes):
LVP = isinstance(dev.compiler, LVPCompiler)
LVP = isinstance(dev.renderer, LVPRenderer)
if sys.platform == "win32": # mypy doesn't understand when WIN is used here
PAGE_EXECUTE_READWRITE, MEM_COMMIT, MEM_RESERVE = 0x40, 0x1000, 0x2000
ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p
@@ -137,5 +136,5 @@ class CPUDevice(HCQCompiled):
self.tasks:queue.Queue = queue.Queue()
CPUWorker(self, self.tasks, thread_id=0).start()
compilers = CompilerSet([CompilerPair(ClangJITRenderer, None), CompilerPair(LLVMRenderer, CPULLVMCompiler, ctrl_var=CPU_LLVM),
CompilerPair(LVPRenderer, LVPCompiler, ctrl_var=CPU_LVP)], ctrl_var=CPU_CC)
CompilerPair(LVPRenderer, None, ctrl_var=CPU_LVP)], ctrl_var=CPU_CC)
super().__init__(device, CPUAllocator(self), compilers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)

View File

@@ -6,7 +6,6 @@ from tinygrad.renderer.llvmir import AMDLLVMRenderer
from tinygrad.uop.ops import Ops
from tinygrad.helpers import cpu_profile, EMULATE, NULL_IR3, NULL_NAK
from tinygrad.renderer.nir import IR3Renderer, NAKRenderer
from tinygrad.runtime.support.compiler_mesa import IR3Compiler, NAKCompiler
class NullRenderer(CStyleLanguage):
device = "NULL"
@@ -39,7 +38,6 @@ class NullDevice(Compiled):
case "AMD_RDNA4": renderer = functools.partial(AMDLLVMRenderer, "gfx1201")
case "": renderer = NullRenderer
case _: raise RuntimeError(f"can't EMULATE device: {EMULATE.value}")
compilers = CompilerSet([CompilerPair(renderer, Compiler),
CompilerPair(functools.partial(IR3Renderer, self), functools.partial(IR3Compiler, 0x6030001), NULL_IR3), # adreno 630
CompilerPair(functools.partial(NAKRenderer, self), functools.partial(NAKCompiler, "sm_120", 48), NULL_NAK)]) # 5090
compilers = CompilerSet([CompilerPair(renderer, Compiler), CompilerPair(functools.partial(IR3Renderer, 0x6030001), None, NULL_IR3), # adreno 630
CompilerPair(functools.partial(NAKRenderer, "sm_120", 48), None, NULL_NAK)]) # 5090
super().__init__(device, NullAllocator(self), compilers, functools.partial(NullProgram, device), NullGraph)

View File

@@ -11,7 +11,6 @@ from tinygrad.helpers import getenv, mv_address, round_up, data64, data64_le, pr
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.cstyle import NVRenderer
from tinygrad.runtime.support.compiler_cuda import CUDACompiler, PTXCompiler, NVPTXCompiler, NVCompiler
from tinygrad.runtime.support.compiler_mesa import NAKCompiler
from tinygrad.runtime.autogen import nv_570, nv_580, pci, mesa
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.support.nv.nvdev import NVDev, NVMemoryManager
@@ -216,7 +215,7 @@ class NVProgram(HCQProgram):
self.dev, self.name, self.lib = dev, name, lib
self.constbufs: dict[int, tuple[int, int]] = {0: (0, 0x160)} # dict[constbuf index, tuple[va_addr, size]]
if (NAK:=isinstance(dev.compiler, NAKCompiler)):
if (NAK:=isinstance(dev.renderer, NAKRenderer)):
image, self.cbuf_0 = memoryview(bytearray(lib[ctypes.sizeof(info:=mesa.struct_nak_shader_info.from_buffer_copy(lib)):])), []
self.regs_usage, self.shmem_usage, self.lcmem_usage = info.num_gprs, round_up(info.cs.smem_size, 128), round_up(info.slm_size, 16)
elif MOCKGPU: image, sections, relocs = memoryview(bytearray(lib) + b'\x00' * (4 - len(lib)%4)).cast("I"), [], [] # type: ignore
@@ -586,7 +585,7 @@ class NVDevice(HCQCompiled[HCQSignal]):
cucc, ptxcc = (CUDACompiler, PTXCompiler) if MOCKGPU else (NVCompiler, NVPTXCompiler)
compilers = CompilerSet(ctrl_var=NV_CC, cset=[CompilerPair(functools.partial(NVRenderer, self.arch),functools.partial(cucc, self.arch)),
CompilerPair(functools.partial(PTXRenderer, self.arch, device="NV"), functools.partial(ptxcc, self.arch), NV_PTX),
CompilerPair(functools.partial(NAKRenderer, dev=self), functools.partial(NAKCompiler, self.arch, self.max_warps_per_sm), NV_NAK)])
CompilerPair(functools.partial(NAKRenderer, self.arch, self.max_warps_per_sm), None, NV_NAK)])
super().__init__(device, NVAllocator(self), compilers, functools.partial(NVProgram, self), HCQSignal, NVComputeQueue, NVCopyQueue)
self._setup_gpfifos()

View File

@@ -10,7 +10,6 @@ from tinygrad.runtime.autogen import kgsl, mesa
from tinygrad.runtime.ops_cl import CLCompiler, CLDevice
from tinygrad.renderer.cstyle import QCOMRenderer
from tinygrad.renderer.nir import IR3Renderer
from tinygrad.runtime.support.compiler_mesa import IR3Compiler
from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport, cpu_profile, lo32, PROFILE, suppress_finalizing
from tinygrad.helpers import flatten, QCOM_IR3, QCOM_CC
from tinygrad.runtime.support.system import System
@@ -227,10 +226,10 @@ class IR3ArgsState(HCQArgsState):
class QCOMProgram(HCQProgram):
def __init__(self, dev: QCOMDevice, name: str, lib: bytes):
self.dev: QCOMDevice = dev
self.name, self.lib, self.NIR = name, lib, isinstance(dev.compiler, IR3Compiler)
self.name, self.lib, self.NIR = name, lib, isinstance(dev.renderer, IR3Renderer)
if self.NIR:
from tinygrad.runtime.autogen import mesa
from tinygrad.runtime.support.compiler_mesa import IR3Compiler
v, cs, self.imm_vals, self.image = IR3Compiler.unpack_lib(lib)
self.prg_offset, self.brnchstck, self.image_size, self.pvtmem, self.shmem = 0, v.branchstack, v.info.size, v.pvtmem_size, v.shared_size
self.wgsz = alloc.offset_vec4 * 4 + 8 if (alloc:=cs.allocs.consts[mesa.IR3_CONST_ALLOC_DRIVER_PARAMS]).size_vec4 else 0xfc
@@ -402,7 +401,7 @@ class QCOMDevice(HCQCompiled):
System.write_sysfs("/sys/class/kgsl/kgsl-3d0/idle_timer", value="4000000000", msg="Failed to disable suspend mode", expected="4294967276")
compilers = CompilerSet(ctrl_var=QCOM_CC, cset=[CompilerPair(QCOMRenderer, functools.partial(QCOMCompiler, device)),
CompilerPair(functools.partial(IR3Renderer, self), functools.partial(IR3Compiler, info.chip_id), QCOM_IR3)])
CompilerPair(functools.partial(IR3Renderer, info.chip_id), None, QCOM_IR3)])
super().__init__(device, QCOMAllocator(self), compilers, functools.partial(QCOMProgram, self), QCOMSignal,
functools.partial(QCOMComputeQueue, self), None)

View File

@@ -29,7 +29,7 @@ def expect(x, err, ret=None):
class LLVMCompiler(Compiler):
jit = True
target_arch = {'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86', 'riscv64': 'riscv64'}[platform.machine()]
def __init__(self, processor:str, feats:str):
def __init__(self, processor:str, feats:str, cache_key=None):
for component in ['Target', 'TargetInfo', 'TargetMC', 'AsmParser', 'AsmPrinter']: getattr(llvm, f'LLVMInitialize{self.target_arch}{component}')()
triple = {'AArch64': b'aarch64-none-unknown-elf', 'X86': b'x86_64-none-unknown-elf', 'AMDGPU': b'amdgcn-amd-amdhsa'}[self.target_arch]
@@ -59,7 +59,7 @@ class LLVMCompiler(Compiler):
self.diag_msgs.append(msg)
self.handle_diag = handle_diag
llvm.LLVMContextSetDiagnosticHandler(self.context, handle_diag, None)
super().__init__(f"compile_llvm_{processor}_{feats}{'_jit' if self.jit else ''}{'_opt' if opt else ''}")
super().__init__(cache_key or f"compile_llvm_{processor}_{feats}{'_jit' if self.jit else ''}{'_opt' if opt else ''}")
def __del__(self):
llvm.LLVMDisposePassBuilderOptions(self.pbo)
@@ -83,7 +83,7 @@ class LLVMCompiler(Compiler):
def disassemble(self, lib:bytes): capstone_flatdump(lib)
class CPULLVMCompiler(LLVMCompiler):
def __init__(self):
def __init__(self, cache_key=None):
# +reserve-x18 here does the same thing as -ffixed-x18 in ops_cpu.py, see comments there for why it's needed on arm osx
cpu, feats = ctypes.string_at(llvm.LLVMGetHostCPUName()), (b'+reserve-x18,' if OSX else b'') + ctypes.string_at(llvm.LLVMGetHostCPUFeatures())
super().__init__(cpu.decode(), feats.decode())
super().__init__(cpu.decode(), feats.decode(), cache_key)

View File

@@ -4,6 +4,8 @@ from tinygrad.helpers import cpu_objdump, system, data64
from tinygrad.runtime.autogen import mesa, llvm
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, expect, cerr
# NB: compilers assume mesa's glsl type cache is managed externally with mesa.glsl_type_singleton_init_or_ref() and mesa.glsl_type_singleton_decref()
def rzalloc(typ, ctx=None, **kwargs):
s = ctypes.cast(mesa.rzalloc_size(ctypes.cast(ctx, ctypes.c_void_p), ctypes.sizeof(typ)), ctypes.POINTER(typ))
for k,v in kwargs.items(): setattr(s.contents, k, v)
@@ -14,20 +16,8 @@ def deserialize(enc_src, opts):
mesa.blob_reader_init(blobreader, src:=base64.b64decode(enc_src), len(src))
return mesa.nir_deserialize(None, ctypes.cast(opts, ctypes.POINTER(mesa.nir_shader_compiler_options)), blobreader)
class NIRCompiler(Compiler):
def __init__(self, cache_key):
mesa.glsl_type_singleton_init_or_ref()
super().__init__(cache_key)
def __del__(self): mesa.glsl_type_singleton_decref()
class LVPCompiler(CPULLVMCompiler, NIRCompiler):
def __init__(self, cache_key="lvp"):
CPULLVMCompiler.__init__(self)
NIRCompiler.__init__(self, f"compile_{cache_key}")
def __del__(self):
NIRCompiler.__del__(self)
CPULLVMCompiler.__del__(self)
class LVPCompiler(CPULLVMCompiler):
def __init__(self, cache_key="lvp"): CPULLVMCompiler.__init__(self, cache_key=f"compile_{cache_key}")
def compile(self, src) -> bytes:
shader, ctx = deserialize(src, mesa.lvp_nir_options), llvm.LLVMGetGlobalContext()
@@ -60,16 +50,14 @@ class LVPCompiler(CPULLVMCompiler, NIRCompiler):
def disassemble(self, lib: bytes): cpu_objdump(lib)
class NAKCompiler(NIRCompiler):
class NAKCompiler(Compiler):
def __init__(self, arch, warps_per_sm, cache_key="nak"):
self.arch, self.warps_per_sm = arch, warps_per_sm
self.cc = mesa.nak_compiler_create(mesa.struct_nv_device_info(sm=int(arch[3:]), max_warps_per_mp=warps_per_sm))
self.nir_options = bytes(mesa.nak_nir_options(self.cc).contents)
super().__init__(f"compile_{cache_key}_{arch}")
def __del__(self):
mesa.nak_compiler_destroy(self.cc)
super().__del__()
def __del__(self): mesa.nak_compiler_destroy(self.cc)
def __reduce__(self): return NAKCompiler, (self.arch, self.warps_per_sm)
@@ -100,7 +88,7 @@ def disas_adreno(lib:bytes, gpu_id=630):
tf.seek(0)
print(tf.read())
class IR3Compiler(NIRCompiler):
class IR3Compiler(Compiler):
def __init__(self, chip_id, cache_key="ir3"):
assert sys.version_info >= (3,14), "IR3 requires python 3.14's bitfield fixes"
self.dev_id = mesa.struct_fd_dev_id(((chip_id >> 24) & 0xFF) * 100 + ((chip_id >> 16) & 0xFF) * 10 + ((chip_id >> 8) & 0xFF), chip_id)
@@ -110,9 +98,7 @@ class IR3Compiler(NIRCompiler):
self.nir_options = bytes(mesa.ir3_get_compiler_options(self.cc).contents)
super().__init__(f"compile_{cache_key}")
def __del__(self):
mesa.ir3_compiler_destroy(self.cc)
super().__del__()
def __del__(self): mesa.ir3_compiler_destroy(self.cc)
def __reduce__(self): return IR3Compiler, (self.dev_id.chip_id,)