NIR: new-style (fix beam) (#13887)

* NIR: fix beam

* new reduce

* Revert "Revert "NIR: new-style compilers (#13875)" (#13888)"

This reverts commit fc4faed0b2.

* oops
This commit is contained in:
Christopher Milan
2025-12-29 15:41:29 -08:00
committed by GitHub
parent fc4faed0b2
commit 0497387e45
7 changed files with 34 additions and 59 deletions

View File

@@ -1,6 +1,6 @@
from typing import Callable, cast, Any from typing import Callable, cast, Any
from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes
from tinygrad.helpers import DEBUG, OSX, unwrap, charptr from tinygrad.helpers import DEBUG, OSX, unwrap, charptr, fromimport
from tinygrad.renderer import Renderer from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str
@@ -115,7 +115,8 @@ def nidx(b:mesa.nir_builder, buf, off, dtype, gate=None) -> mesa.nir_def:
return if_phi(b, gate, f, lambda: buf) if gate is not None else f() return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
class NIRRenderer(Renderer): class NIRRenderer(Renderer):
suffix = "NAK" suffix = "NIR"
nir_options: bytes
global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
code_for_op = {**{k:lambda:None for k in u_aop.keys()}, **{k:lambda:None for k in s_aop.keys()}, **{k:lambda:None for k in f_aop.keys()}} code_for_op = {**{k:lambda:None for k in u_aop.keys()}, **{k:lambda:None for k in s_aop.keys()}, **{k:lambda:None for k in f_aop.keys()}}
@@ -158,13 +159,17 @@ class NIRRenderer(Renderer):
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]]))) (UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
]) ])
def __init__(self): mesa.glsl_type_singleton_init_or_ref() def __reduce__(self): return self.__class__, self.args
def __init__(self, *args):
self.compiler = fromimport("tinygrad.runtime.support.compiler_mesa", self.__class__.__name__.replace("Renderer", "Compiler"))(*args)
self.args = args
if hasattr(self.compiler, "nir_options"): self.nir_options = self.compiler.nir_options
mesa.glsl_type_singleton_init_or_ref()
def __del__(self): def __del__(self):
with contextlib.suppress(AttributeError): mesa.glsl_type_singleton_decref() with contextlib.suppress(AttributeError): mesa.glsl_type_singleton_decref()
@property
def nir_options(self): raise NotImplementedError("needs nir_options")
def param(self, b:mesa.nir_builder, x, sz:int) -> mesa.nir_def: raise NotImplementedError("needs param") def param(self, b:mesa.nir_builder, x, sz:int) -> mesa.nir_def: raise NotImplementedError("needs param")
def prerender(self, uops:list[UOp]): def prerender(self, uops:list[UOp]):
self.b = mesa.nir_builder_init_simple_shader(mesa.MESA_SHADER_COMPUTE, mesa.nir_shader_compiler_options.from_buffer_copy(self.nir_options), None) self.b = mesa.nir_builder_init_simple_shader(mesa.MESA_SHADER_COMPUTE, mesa.nir_shader_compiler_options.from_buffer_copy(self.nir_options), None)
@@ -216,20 +221,9 @@ class NIRRenderer(Renderer):
return ret return ret
class NIRRendererWithOpts(NIRRenderer): class NAKRenderer(NIRRenderer):
def __init__(self, dev=None, nir_options=None):
self.dev, self._nir_options = dev, nir_options
super().__init__()
def __reduce__(self): return self.__class__, (None, self.nir_options)
@property
def nir_options(self):
if self._nir_options is None: self._nir_options = self.dev.compiler.nir_options
return self._nir_options
class NAKRenderer(NIRRendererWithOpts):
device = "NV" device = "NV"
param = nir_instr(nc=1, num_components=1, bs=lambda sz:sz*8, also=lambda self,sz: setattr(self, "param_idx", self.param_idx + sz), param = nir_instr(nc=1, num_components=1, bs=lambda sz:sz*8, also=lambda self,sz: setattr(self, "param_idx", self.param_idx + sz),
intrins={"ALIGN_MUL":lambda sz:sz}, srcs=lambda self,b: [nsrc(nimm(b, 0, dtypes.int)), nsrc(nimm(b, self.param_idx, dtypes.int))])( intrins={"ALIGN_MUL":lambda sz:sz}, srcs=lambda self,b: [nsrc(nimm(b, 0, dtypes.int)), nsrc(nimm(b, self.param_idx, dtypes.int))])(
lambda self, b, x, sz: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_ldc_nv)) lambda self, b, x, sz: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_ldc_nv))
@@ -261,7 +255,7 @@ _nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2
nc=4, bs=32, num_components=4, srcs=lambda b,img,coord:[nsrc(x) for x in [img, tovec(b, coord), nundef(b, dtypes.int), nimm(b, 0, dtypes.int)]])( nc=4, bs=32, num_components=4, srcs=lambda b,img,coord:[nsrc(x) for x in [img, tovec(b, coord), nundef(b, dtypes.int), nimm(b, 0, dtypes.int)]])(
lambda b,img,coord,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load"))) lambda b,img,coord,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load")))
class IR3Renderer(NIRRendererWithOpts): class IR3Renderer(NIRRenderer):
device = "QCOM" device = "QCOM"
def nload_img(ctx,img,coord): def nload_img(ctx,img,coord):

View File

@@ -9,7 +9,6 @@ from tinygrad.renderer.cstyle import ClangJITRenderer
from tinygrad.renderer.llvmir import LLVMRenderer from tinygrad.renderer.llvmir import LLVMRenderer
from tinygrad.renderer.nir import LVPRenderer from tinygrad.renderer.nir import LVPRenderer
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler
from tinygrad.runtime.support.compiler_mesa import LVPCompiler
from tinygrad.runtime.support.elf import jit_loader from tinygrad.runtime.support.elf import jit_loader
from tinygrad.uop.ops import sint from tinygrad.uop.ops import sint
@@ -72,7 +71,7 @@ class CPUProgram(HCQProgram):
except OSError: pass except OSError: pass
def __init__(self, dev, name:str, lib:bytes): def __init__(self, dev, name:str, lib:bytes):
LVP = isinstance(dev.compiler, LVPCompiler) LVP = isinstance(dev.renderer, LVPRenderer)
if sys.platform == "win32": # mypy doesn't understand when WIN is used here if sys.platform == "win32": # mypy doesn't understand when WIN is used here
PAGE_EXECUTE_READWRITE, MEM_COMMIT, MEM_RESERVE = 0x40, 0x1000, 0x2000 PAGE_EXECUTE_READWRITE, MEM_COMMIT, MEM_RESERVE = 0x40, 0x1000, 0x2000
ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p
@@ -137,5 +136,5 @@ class CPUDevice(HCQCompiled):
self.tasks:queue.Queue = queue.Queue() self.tasks:queue.Queue = queue.Queue()
CPUWorker(self, self.tasks, thread_id=0).start() CPUWorker(self, self.tasks, thread_id=0).start()
compilers = CompilerSet([CompilerPair(ClangJITRenderer, None), CompilerPair(LLVMRenderer, CPULLVMCompiler, ctrl_var=CPU_LLVM), compilers = CompilerSet([CompilerPair(ClangJITRenderer, None), CompilerPair(LLVMRenderer, CPULLVMCompiler, ctrl_var=CPU_LLVM),
CompilerPair(LVPRenderer, LVPCompiler, ctrl_var=CPU_LVP)], ctrl_var=CPU_CC) CompilerPair(LVPRenderer, None, ctrl_var=CPU_LVP)], ctrl_var=CPU_CC)
super().__init__(device, CPUAllocator(self), compilers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue) super().__init__(device, CPUAllocator(self), compilers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)

View File

@@ -6,7 +6,6 @@ from tinygrad.renderer.llvmir import AMDLLVMRenderer
from tinygrad.uop.ops import Ops from tinygrad.uop.ops import Ops
from tinygrad.helpers import cpu_profile, EMULATE, NULL_IR3, NULL_NAK from tinygrad.helpers import cpu_profile, EMULATE, NULL_IR3, NULL_NAK
from tinygrad.renderer.nir import IR3Renderer, NAKRenderer from tinygrad.renderer.nir import IR3Renderer, NAKRenderer
from tinygrad.runtime.support.compiler_mesa import IR3Compiler, NAKCompiler
class NullRenderer(CStyleLanguage): class NullRenderer(CStyleLanguage):
device = "NULL" device = "NULL"
@@ -39,7 +38,6 @@ class NullDevice(Compiled):
case "AMD_RDNA4": renderer = functools.partial(AMDLLVMRenderer, "gfx1201") case "AMD_RDNA4": renderer = functools.partial(AMDLLVMRenderer, "gfx1201")
case "": renderer = NullRenderer case "": renderer = NullRenderer
case _: raise RuntimeError(f"can't EMULATE device: {EMULATE.value}") case _: raise RuntimeError(f"can't EMULATE device: {EMULATE.value}")
compilers = CompilerSet([CompilerPair(renderer, Compiler), compilers = CompilerSet([CompilerPair(renderer, Compiler), CompilerPair(functools.partial(IR3Renderer, 0x6030001), None, NULL_IR3), # adreno 630
CompilerPair(functools.partial(IR3Renderer, self), functools.partial(IR3Compiler, 0x6030001), NULL_IR3), # adreno 630 CompilerPair(functools.partial(NAKRenderer, "sm_120", 48), None, NULL_NAK)]) # 5090
CompilerPair(functools.partial(NAKRenderer, self), functools.partial(NAKCompiler, "sm_120", 48), NULL_NAK)]) # 5090
super().__init__(device, NullAllocator(self), compilers, functools.partial(NullProgram, device), NullGraph) super().__init__(device, NullAllocator(self), compilers, functools.partial(NullProgram, device), NullGraph)

View File

@@ -11,7 +11,6 @@ from tinygrad.helpers import getenv, mv_address, round_up, data64, data64_le, pr
from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.cstyle import NVRenderer from tinygrad.renderer.cstyle import NVRenderer
from tinygrad.runtime.support.compiler_cuda import CUDACompiler, PTXCompiler, NVPTXCompiler, NVCompiler from tinygrad.runtime.support.compiler_cuda import CUDACompiler, PTXCompiler, NVPTXCompiler, NVCompiler
from tinygrad.runtime.support.compiler_mesa import NAKCompiler
from tinygrad.runtime.autogen import nv_570, nv_580, pci, mesa from tinygrad.runtime.autogen import nv_570, nv_580, pci, mesa
from tinygrad.runtime.support.elf import elf_loader from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.support.nv.nvdev import NVDev, NVMemoryManager from tinygrad.runtime.support.nv.nvdev import NVDev, NVMemoryManager
@@ -216,7 +215,7 @@ class NVProgram(HCQProgram):
self.dev, self.name, self.lib = dev, name, lib self.dev, self.name, self.lib = dev, name, lib
self.constbufs: dict[int, tuple[int, int]] = {0: (0, 0x160)} # dict[constbuf index, tuple[va_addr, size]] self.constbufs: dict[int, tuple[int, int]] = {0: (0, 0x160)} # dict[constbuf index, tuple[va_addr, size]]
if (NAK:=isinstance(dev.compiler, NAKCompiler)): if (NAK:=isinstance(dev.renderer, NAKRenderer)):
image, self.cbuf_0 = memoryview(bytearray(lib[ctypes.sizeof(info:=mesa.struct_nak_shader_info.from_buffer_copy(lib)):])), [] image, self.cbuf_0 = memoryview(bytearray(lib[ctypes.sizeof(info:=mesa.struct_nak_shader_info.from_buffer_copy(lib)):])), []
self.regs_usage, self.shmem_usage, self.lcmem_usage = info.num_gprs, round_up(info.cs.smem_size, 128), round_up(info.slm_size, 16) self.regs_usage, self.shmem_usage, self.lcmem_usage = info.num_gprs, round_up(info.cs.smem_size, 128), round_up(info.slm_size, 16)
elif MOCKGPU: image, sections, relocs = memoryview(bytearray(lib) + b'\x00' * (4 - len(lib)%4)).cast("I"), [], [] # type: ignore elif MOCKGPU: image, sections, relocs = memoryview(bytearray(lib) + b'\x00' * (4 - len(lib)%4)).cast("I"), [], [] # type: ignore
@@ -586,7 +585,7 @@ class NVDevice(HCQCompiled[HCQSignal]):
cucc, ptxcc = (CUDACompiler, PTXCompiler) if MOCKGPU else (NVCompiler, NVPTXCompiler) cucc, ptxcc = (CUDACompiler, PTXCompiler) if MOCKGPU else (NVCompiler, NVPTXCompiler)
compilers = CompilerSet(ctrl_var=NV_CC, cset=[CompilerPair(functools.partial(NVRenderer, self.arch),functools.partial(cucc, self.arch)), compilers = CompilerSet(ctrl_var=NV_CC, cset=[CompilerPair(functools.partial(NVRenderer, self.arch),functools.partial(cucc, self.arch)),
CompilerPair(functools.partial(PTXRenderer, self.arch, device="NV"), functools.partial(ptxcc, self.arch), NV_PTX), CompilerPair(functools.partial(PTXRenderer, self.arch, device="NV"), functools.partial(ptxcc, self.arch), NV_PTX),
CompilerPair(functools.partial(NAKRenderer, dev=self), functools.partial(NAKCompiler, self.arch, self.max_warps_per_sm), NV_NAK)]) CompilerPair(functools.partial(NAKRenderer, self.arch, self.max_warps_per_sm), None, NV_NAK)])
super().__init__(device, NVAllocator(self), compilers, functools.partial(NVProgram, self), HCQSignal, NVComputeQueue, NVCopyQueue) super().__init__(device, NVAllocator(self), compilers, functools.partial(NVProgram, self), HCQSignal, NVComputeQueue, NVCopyQueue)
self._setup_gpfifos() self._setup_gpfifos()

View File

@@ -10,7 +10,6 @@ from tinygrad.runtime.autogen import kgsl, mesa
from tinygrad.runtime.ops_cl import CLCompiler, CLDevice from tinygrad.runtime.ops_cl import CLCompiler, CLDevice
from tinygrad.renderer.cstyle import QCOMRenderer from tinygrad.renderer.cstyle import QCOMRenderer
from tinygrad.renderer.nir import IR3Renderer from tinygrad.renderer.nir import IR3Renderer
from tinygrad.runtime.support.compiler_mesa import IR3Compiler
from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport, cpu_profile, lo32, PROFILE, suppress_finalizing from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport, cpu_profile, lo32, PROFILE, suppress_finalizing
from tinygrad.helpers import flatten, QCOM_IR3, QCOM_CC from tinygrad.helpers import flatten, QCOM_IR3, QCOM_CC
from tinygrad.runtime.support.system import System from tinygrad.runtime.support.system import System
@@ -227,10 +226,10 @@ class IR3ArgsState(HCQArgsState):
class QCOMProgram(HCQProgram): class QCOMProgram(HCQProgram):
def __init__(self, dev: QCOMDevice, name: str, lib: bytes): def __init__(self, dev: QCOMDevice, name: str, lib: bytes):
self.dev: QCOMDevice = dev self.dev: QCOMDevice = dev
self.name, self.lib, self.NIR = name, lib, isinstance(dev.compiler, IR3Compiler) self.name, self.lib, self.NIR = name, lib, isinstance(dev.renderer, IR3Renderer)
if self.NIR: if self.NIR:
from tinygrad.runtime.autogen import mesa from tinygrad.runtime.support.compiler_mesa import IR3Compiler
v, cs, self.imm_vals, self.image = IR3Compiler.unpack_lib(lib) v, cs, self.imm_vals, self.image = IR3Compiler.unpack_lib(lib)
self.prg_offset, self.brnchstck, self.image_size, self.pvtmem, self.shmem = 0, v.branchstack, v.info.size, v.pvtmem_size, v.shared_size self.prg_offset, self.brnchstck, self.image_size, self.pvtmem, self.shmem = 0, v.branchstack, v.info.size, v.pvtmem_size, v.shared_size
self.wgsz = alloc.offset_vec4 * 4 + 8 if (alloc:=cs.allocs.consts[mesa.IR3_CONST_ALLOC_DRIVER_PARAMS]).size_vec4 else 0xfc self.wgsz = alloc.offset_vec4 * 4 + 8 if (alloc:=cs.allocs.consts[mesa.IR3_CONST_ALLOC_DRIVER_PARAMS]).size_vec4 else 0xfc
@@ -402,7 +401,7 @@ class QCOMDevice(HCQCompiled):
System.write_sysfs("/sys/class/kgsl/kgsl-3d0/idle_timer", value="4000000000", msg="Failed to disable suspend mode", expected="4294967276") System.write_sysfs("/sys/class/kgsl/kgsl-3d0/idle_timer", value="4000000000", msg="Failed to disable suspend mode", expected="4294967276")
compilers = CompilerSet(ctrl_var=QCOM_CC, cset=[CompilerPair(QCOMRenderer, functools.partial(QCOMCompiler, device)), compilers = CompilerSet(ctrl_var=QCOM_CC, cset=[CompilerPair(QCOMRenderer, functools.partial(QCOMCompiler, device)),
CompilerPair(functools.partial(IR3Renderer, self), functools.partial(IR3Compiler, info.chip_id), QCOM_IR3)]) CompilerPair(functools.partial(IR3Renderer, info.chip_id), None, QCOM_IR3)])
super().__init__(device, QCOMAllocator(self), compilers, functools.partial(QCOMProgram, self), QCOMSignal, super().__init__(device, QCOMAllocator(self), compilers, functools.partial(QCOMProgram, self), QCOMSignal,
functools.partial(QCOMComputeQueue, self), None) functools.partial(QCOMComputeQueue, self), None)

View File

@@ -29,7 +29,7 @@ def expect(x, err, ret=None):
class LLVMCompiler(Compiler): class LLVMCompiler(Compiler):
jit = True jit = True
target_arch = {'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86', 'riscv64': 'riscv64'}[platform.machine()] target_arch = {'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86', 'riscv64': 'riscv64'}[platform.machine()]
def __init__(self, processor:str, feats:str): def __init__(self, processor:str, feats:str, cache_key=None):
for component in ['Target', 'TargetInfo', 'TargetMC', 'AsmParser', 'AsmPrinter']: getattr(llvm, f'LLVMInitialize{self.target_arch}{component}')() for component in ['Target', 'TargetInfo', 'TargetMC', 'AsmParser', 'AsmPrinter']: getattr(llvm, f'LLVMInitialize{self.target_arch}{component}')()
triple = {'AArch64': b'aarch64-none-unknown-elf', 'X86': b'x86_64-none-unknown-elf', 'AMDGPU': b'amdgcn-amd-amdhsa'}[self.target_arch] triple = {'AArch64': b'aarch64-none-unknown-elf', 'X86': b'x86_64-none-unknown-elf', 'AMDGPU': b'amdgcn-amd-amdhsa'}[self.target_arch]
@@ -59,7 +59,7 @@ class LLVMCompiler(Compiler):
self.diag_msgs.append(msg) self.diag_msgs.append(msg)
self.handle_diag = handle_diag self.handle_diag = handle_diag
llvm.LLVMContextSetDiagnosticHandler(self.context, handle_diag, None) llvm.LLVMContextSetDiagnosticHandler(self.context, handle_diag, None)
super().__init__(f"compile_llvm_{processor}_{feats}{'_jit' if self.jit else ''}{'_opt' if opt else ''}") super().__init__(cache_key or f"compile_llvm_{processor}_{feats}{'_jit' if self.jit else ''}{'_opt' if opt else ''}")
def __del__(self): def __del__(self):
llvm.LLVMDisposePassBuilderOptions(self.pbo) llvm.LLVMDisposePassBuilderOptions(self.pbo)
@@ -83,7 +83,7 @@ class LLVMCompiler(Compiler):
def disassemble(self, lib:bytes): capstone_flatdump(lib) def disassemble(self, lib:bytes): capstone_flatdump(lib)
class CPULLVMCompiler(LLVMCompiler): class CPULLVMCompiler(LLVMCompiler):
def __init__(self): def __init__(self, cache_key=None):
# +reserve-x18 here does the same thing as -ffixed-x18 in ops_cpu.py, see comments there for why it's needed on arm osx # +reserve-x18 here does the same thing as -ffixed-x18 in ops_cpu.py, see comments there for why it's needed on arm osx
cpu, feats = ctypes.string_at(llvm.LLVMGetHostCPUName()), (b'+reserve-x18,' if OSX else b'') + ctypes.string_at(llvm.LLVMGetHostCPUFeatures()) cpu, feats = ctypes.string_at(llvm.LLVMGetHostCPUName()), (b'+reserve-x18,' if OSX else b'') + ctypes.string_at(llvm.LLVMGetHostCPUFeatures())
super().__init__(cpu.decode(), feats.decode()) super().__init__(cpu.decode(), feats.decode(), cache_key)

View File

@@ -4,6 +4,8 @@ from tinygrad.helpers import cpu_objdump, system, data64
from tinygrad.runtime.autogen import mesa, llvm from tinygrad.runtime.autogen import mesa, llvm
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, expect, cerr from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, expect, cerr
# NB: compilers assume mesa's glsl type cache is managed externally with mesa.glsl_type_singleton_init_or_ref() and mesa.glsl_type_singleton_decref()
def rzalloc(typ, ctx=None, **kwargs): def rzalloc(typ, ctx=None, **kwargs):
s = ctypes.cast(mesa.rzalloc_size(ctypes.cast(ctx, ctypes.c_void_p), ctypes.sizeof(typ)), ctypes.POINTER(typ)) s = ctypes.cast(mesa.rzalloc_size(ctypes.cast(ctx, ctypes.c_void_p), ctypes.sizeof(typ)), ctypes.POINTER(typ))
for k,v in kwargs.items(): setattr(s.contents, k, v) for k,v in kwargs.items(): setattr(s.contents, k, v)
@@ -14,20 +16,8 @@ def deserialize(enc_src, opts):
mesa.blob_reader_init(blobreader, src:=base64.b64decode(enc_src), len(src)) mesa.blob_reader_init(blobreader, src:=base64.b64decode(enc_src), len(src))
return mesa.nir_deserialize(None, ctypes.cast(opts, ctypes.POINTER(mesa.nir_shader_compiler_options)), blobreader) return mesa.nir_deserialize(None, ctypes.cast(opts, ctypes.POINTER(mesa.nir_shader_compiler_options)), blobreader)
class NIRCompiler(Compiler): class LVPCompiler(CPULLVMCompiler):
def __init__(self, cache_key): def __init__(self, cache_key="lvp"): CPULLVMCompiler.__init__(self, cache_key=f"compile_{cache_key}")
mesa.glsl_type_singleton_init_or_ref()
super().__init__(cache_key)
def __del__(self): mesa.glsl_type_singleton_decref()
class LVPCompiler(CPULLVMCompiler, NIRCompiler):
def __init__(self, cache_key="lvp"):
CPULLVMCompiler.__init__(self)
NIRCompiler.__init__(self, f"compile_{cache_key}")
def __del__(self):
NIRCompiler.__del__(self)
CPULLVMCompiler.__del__(self)
def compile(self, src) -> bytes: def compile(self, src) -> bytes:
shader, ctx = deserialize(src, mesa.lvp_nir_options), llvm.LLVMGetGlobalContext() shader, ctx = deserialize(src, mesa.lvp_nir_options), llvm.LLVMGetGlobalContext()
@@ -60,16 +50,14 @@ class LVPCompiler(CPULLVMCompiler, NIRCompiler):
def disassemble(self, lib: bytes): cpu_objdump(lib) def disassemble(self, lib: bytes): cpu_objdump(lib)
class NAKCompiler(NIRCompiler): class NAKCompiler(Compiler):
def __init__(self, arch, warps_per_sm, cache_key="nak"): def __init__(self, arch, warps_per_sm, cache_key="nak"):
self.arch, self.warps_per_sm = arch, warps_per_sm self.arch, self.warps_per_sm = arch, warps_per_sm
self.cc = mesa.nak_compiler_create(mesa.struct_nv_device_info(sm=int(arch[3:]), max_warps_per_mp=warps_per_sm)) self.cc = mesa.nak_compiler_create(mesa.struct_nv_device_info(sm=int(arch[3:]), max_warps_per_mp=warps_per_sm))
self.nir_options = bytes(mesa.nak_nir_options(self.cc).contents) self.nir_options = bytes(mesa.nak_nir_options(self.cc).contents)
super().__init__(f"compile_{cache_key}_{arch}") super().__init__(f"compile_{cache_key}_{arch}")
def __del__(self): def __del__(self): mesa.nak_compiler_destroy(self.cc)
mesa.nak_compiler_destroy(self.cc)
super().__del__()
def __reduce__(self): return NAKCompiler, (self.arch, self.warps_per_sm) def __reduce__(self): return NAKCompiler, (self.arch, self.warps_per_sm)
@@ -100,7 +88,7 @@ def disas_adreno(lib:bytes, gpu_id=630):
tf.seek(0) tf.seek(0)
print(tf.read()) print(tf.read())
class IR3Compiler(NIRCompiler): class IR3Compiler(Compiler):
def __init__(self, chip_id, cache_key="ir3"): def __init__(self, chip_id, cache_key="ir3"):
assert sys.version_info >= (3,14), "IR3 requires python 3.14's bitfield fixes" assert sys.version_info >= (3,14), "IR3 requires python 3.14's bitfield fixes"
self.dev_id = mesa.struct_fd_dev_id(((chip_id >> 24) & 0xFF) * 100 + ((chip_id >> 16) & 0xFF) * 10 + ((chip_id >> 8) & 0xFF), chip_id) self.dev_id = mesa.struct_fd_dev_id(((chip_id >> 24) & 0xFF) * 100 + ((chip_id >> 16) & 0xFF) * 10 + ((chip_id >> 8) & 0xFF), chip_id)
@@ -110,9 +98,7 @@ class IR3Compiler(NIRCompiler):
self.nir_options = bytes(mesa.ir3_get_compiler_options(self.cc).contents) self.nir_options = bytes(mesa.ir3_get_compiler_options(self.cc).contents)
super().__init__(f"compile_{cache_key}") super().__init__(f"compile_{cache_key}")
def __del__(self): def __del__(self): mesa.ir3_compiler_destroy(self.cc)
mesa.ir3_compiler_destroy(self.cc)
super().__del__()
def __reduce__(self): return IR3Compiler, (self.dev_id.chip_id,) def __reduce__(self): return IR3Compiler, (self.dev_id.chip_id,)