mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
This reverts commit 72236bbd3d.
This commit is contained in:
committed by
GitHub
parent
94bca91f3e
commit
fc4faed0b2
@@ -1,6 +1,6 @@
|
||||
from typing import Callable, cast, Any
|
||||
from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes
|
||||
from tinygrad.helpers import DEBUG, OSX, unwrap, charptr, fromimport
|
||||
from tinygrad.helpers import DEBUG, OSX, unwrap, charptr
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str
|
||||
@@ -115,8 +115,7 @@ def nidx(b:mesa.nir_builder, buf, off, dtype, gate=None) -> mesa.nir_def:
|
||||
return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
|
||||
|
||||
class NIRRenderer(Renderer):
|
||||
suffix = "NIR"
|
||||
nir_options: bytes
|
||||
suffix = "NAK"
|
||||
global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
|
||||
code_for_op = {**{k:lambda:None for k in u_aop.keys()}, **{k:lambda:None for k in s_aop.keys()}, **{k:lambda:None for k in f_aop.keys()}}
|
||||
|
||||
@@ -159,14 +158,13 @@ class NIRRenderer(Renderer):
|
||||
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
|
||||
])
|
||||
|
||||
def __init__(self, compiler):
|
||||
self.compiler = compiler
|
||||
if hasattr(self.compiler, "nir_options"): self.nir_options = self.compiler.nir_options
|
||||
mesa.glsl_type_singleton_init_or_ref()
|
||||
def __init__(self): mesa.glsl_type_singleton_init_or_ref()
|
||||
|
||||
def __del__(self):
|
||||
with contextlib.suppress(AttributeError): mesa.glsl_type_singleton_decref()
|
||||
|
||||
@property
|
||||
def nir_options(self): raise NotImplementedError("needs nir_options")
|
||||
def param(self, b:mesa.nir_builder, x, sz:int) -> mesa.nir_def: raise NotImplementedError("needs param")
|
||||
def prerender(self, uops:list[UOp]):
|
||||
self.b = mesa.nir_builder_init_simple_shader(mesa.MESA_SHADER_COMPUTE, mesa.nir_shader_compiler_options.from_buffer_copy(self.nir_options), None)
|
||||
@@ -218,11 +216,20 @@ class NIRRenderer(Renderer):
|
||||
|
||||
return ret
|
||||
|
||||
class NAKRenderer(NIRRenderer):
|
||||
class NIRRendererWithOpts(NIRRenderer):
|
||||
def __init__(self, dev=None, nir_options=None):
|
||||
self.dev, self._nir_options = dev, nir_options
|
||||
super().__init__()
|
||||
|
||||
def __reduce__(self): return self.__class__, (None, self.nir_options)
|
||||
|
||||
@property
|
||||
def nir_options(self):
|
||||
if self._nir_options is None: self._nir_options = self.dev.compiler.nir_options
|
||||
return self._nir_options
|
||||
|
||||
class NAKRenderer(NIRRendererWithOpts):
|
||||
device = "NV"
|
||||
|
||||
def __init__(self, arch, warps_per_sm): super().__init__(fromimport("tinygrad.runtime.support.compiler_mesa", "NAKCompiler")(arch, warps_per_sm))
|
||||
|
||||
param = nir_instr(nc=1, num_components=1, bs=lambda sz:sz*8, also=lambda self,sz: setattr(self, "param_idx", self.param_idx + sz),
|
||||
intrins={"ALIGN_MUL":lambda sz:sz}, srcs=lambda self,b: [nsrc(nimm(b, 0, dtypes.int)), nsrc(nimm(b, self.param_idx, dtypes.int))])(
|
||||
lambda self, b, x, sz: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_ldc_nv))
|
||||
@@ -234,8 +241,6 @@ class LVPRenderer(NIRRenderer):
|
||||
global_max = (1, 0, 0)
|
||||
nir_options = mesa.lvp_nir_options
|
||||
|
||||
def __init__(self): super().__init__(fromimport("tinygrad.runtime.support.compiler_mesa", "LVPCompiler")())
|
||||
|
||||
param = nir_instr(nc=1, bs=lambda sz: sz * 8, num_components=1, intrins={"ALIGN_MUL":lambda sz: sz, "RANGE":lambda self: self.param_sz},
|
||||
srcs=lambda b, self: [nsrc(nimm(b, 0, dtypes.int)), nsrc(nimm(b, self.param_idx, dtypes.int))], also=lambda self, sz:
|
||||
setattr(self, "param_idx", self.param_idx+sz))(lambda self,b,x,sz: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_ubo))
|
||||
@@ -256,11 +261,9 @@ _nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2
|
||||
nc=4, bs=32, num_components=4, srcs=lambda b,img,coord:[nsrc(x) for x in [img, tovec(b, coord), nundef(b, dtypes.int), nimm(b, 0, dtypes.int)]])(
|
||||
lambda b,img,coord,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load")))
|
||||
|
||||
class IR3Renderer(NIRRenderer):
|
||||
class IR3Renderer(NIRRendererWithOpts):
|
||||
device = "QCOM"
|
||||
|
||||
def __init__(self, chip_id): super().__init__(fromimport("tinygrad.runtime.support.compiler_mesa", "IR3Compiler")(chip_id))
|
||||
|
||||
def nload_img(ctx,img,coord):
|
||||
ctx.texs.add(img)
|
||||
return _nload_img(ctx.b, ctx.r[img], ctx.r[coord], img.dtype)
|
||||
|
||||
@@ -9,6 +9,7 @@ from tinygrad.renderer.cstyle import ClangJITRenderer
|
||||
from tinygrad.renderer.llvmir import LLVMRenderer
|
||||
from tinygrad.renderer.nir import LVPRenderer
|
||||
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler
|
||||
from tinygrad.runtime.support.compiler_mesa import LVPCompiler
|
||||
from tinygrad.runtime.support.elf import jit_loader
|
||||
from tinygrad.uop.ops import sint
|
||||
|
||||
@@ -71,7 +72,7 @@ class CPUProgram(HCQProgram):
|
||||
except OSError: pass
|
||||
|
||||
def __init__(self, dev, name:str, lib:bytes):
|
||||
LVP = isinstance(dev.renderer, LVPRenderer)
|
||||
LVP = isinstance(dev.compiler, LVPCompiler)
|
||||
if sys.platform == "win32": # mypy doesn't understand when WIN is used here
|
||||
PAGE_EXECUTE_READWRITE, MEM_COMMIT, MEM_RESERVE = 0x40, 0x1000, 0x2000
|
||||
ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p
|
||||
@@ -136,5 +137,5 @@ class CPUDevice(HCQCompiled):
|
||||
self.tasks:queue.Queue = queue.Queue()
|
||||
CPUWorker(self, self.tasks, thread_id=0).start()
|
||||
compilers = CompilerSet([CompilerPair(ClangJITRenderer, None), CompilerPair(LLVMRenderer, CPULLVMCompiler, ctrl_var=CPU_LLVM),
|
||||
CompilerPair(LVPRenderer, None, ctrl_var=CPU_LVP)], ctrl_var=CPU_CC)
|
||||
CompilerPair(LVPRenderer, LVPCompiler, ctrl_var=CPU_LVP)], ctrl_var=CPU_CC)
|
||||
super().__init__(device, CPUAllocator(self), compilers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)
|
||||
|
||||
@@ -6,6 +6,7 @@ from tinygrad.renderer.llvmir import AMDLLVMRenderer
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.helpers import cpu_profile, EMULATE, NULL_IR3, NULL_NAK
|
||||
from tinygrad.renderer.nir import IR3Renderer, NAKRenderer
|
||||
from tinygrad.runtime.support.compiler_mesa import IR3Compiler, NAKCompiler
|
||||
|
||||
class NullRenderer(CStyleLanguage):
|
||||
device = "NULL"
|
||||
@@ -38,6 +39,7 @@ class NullDevice(Compiled):
|
||||
case "AMD_RDNA4": renderer = functools.partial(AMDLLVMRenderer, "gfx1201")
|
||||
case "": renderer = NullRenderer
|
||||
case _: raise RuntimeError(f"can't EMULATE device: {EMULATE.value}")
|
||||
compilers = CompilerSet([CompilerPair(renderer, Compiler), CompilerPair(functools.partial(IR3Renderer, 0x6030001), None, NULL_IR3), # adreno 630
|
||||
CompilerPair(functools.partial(NAKRenderer, "sm_120", 48), None, NULL_NAK)]) # 5090
|
||||
compilers = CompilerSet([CompilerPair(renderer, Compiler),
|
||||
CompilerPair(functools.partial(IR3Renderer, self), functools.partial(IR3Compiler, 0x6030001), NULL_IR3), # adreno 630
|
||||
CompilerPair(functools.partial(NAKRenderer, self), functools.partial(NAKCompiler, "sm_120", 48), NULL_NAK)]) # 5090
|
||||
super().__init__(device, NullAllocator(self), compilers, functools.partial(NullProgram, device), NullGraph)
|
||||
|
||||
@@ -11,6 +11,7 @@ from tinygrad.helpers import getenv, mv_address, round_up, data64, data64_le, pr
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.cstyle import NVRenderer
|
||||
from tinygrad.runtime.support.compiler_cuda import CUDACompiler, PTXCompiler, NVPTXCompiler, NVCompiler
|
||||
from tinygrad.runtime.support.compiler_mesa import NAKCompiler
|
||||
from tinygrad.runtime.autogen import nv_570, nv_580, pci, mesa
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
from tinygrad.runtime.support.nv.nvdev import NVDev, NVMemoryManager
|
||||
@@ -215,7 +216,7 @@ class NVProgram(HCQProgram):
|
||||
self.dev, self.name, self.lib = dev, name, lib
|
||||
self.constbufs: dict[int, tuple[int, int]] = {0: (0, 0x160)} # dict[constbuf index, tuple[va_addr, size]]
|
||||
|
||||
if (NAK:=isinstance(dev.renderer, NAKRenderer)):
|
||||
if (NAK:=isinstance(dev.compiler, NAKCompiler)):
|
||||
image, self.cbuf_0 = memoryview(bytearray(lib[ctypes.sizeof(info:=mesa.struct_nak_shader_info.from_buffer_copy(lib)):])), []
|
||||
self.regs_usage, self.shmem_usage, self.lcmem_usage = info.num_gprs, round_up(info.cs.smem_size, 128), round_up(info.slm_size, 16)
|
||||
elif MOCKGPU: image, sections, relocs = memoryview(bytearray(lib) + b'\x00' * (4 - len(lib)%4)).cast("I"), [], [] # type: ignore
|
||||
@@ -585,7 +586,7 @@ class NVDevice(HCQCompiled[HCQSignal]):
|
||||
cucc, ptxcc = (CUDACompiler, PTXCompiler) if MOCKGPU else (NVCompiler, NVPTXCompiler)
|
||||
compilers = CompilerSet(ctrl_var=NV_CC, cset=[CompilerPair(functools.partial(NVRenderer, self.arch),functools.partial(cucc, self.arch)),
|
||||
CompilerPair(functools.partial(PTXRenderer, self.arch, device="NV"), functools.partial(ptxcc, self.arch), NV_PTX),
|
||||
CompilerPair(functools.partial(NAKRenderer, self.arch, self.max_warps_per_sm), None, NV_NAK)])
|
||||
CompilerPair(functools.partial(NAKRenderer, dev=self), functools.partial(NAKCompiler, self.arch, self.max_warps_per_sm), NV_NAK)])
|
||||
super().__init__(device, NVAllocator(self), compilers, functools.partial(NVProgram, self), HCQSignal, NVComputeQueue, NVCopyQueue)
|
||||
|
||||
self._setup_gpfifos()
|
||||
|
||||
@@ -10,6 +10,7 @@ from tinygrad.runtime.autogen import kgsl, mesa
|
||||
from tinygrad.runtime.ops_cl import CLCompiler, CLDevice
|
||||
from tinygrad.renderer.cstyle import QCOMRenderer
|
||||
from tinygrad.renderer.nir import IR3Renderer
|
||||
from tinygrad.runtime.support.compiler_mesa import IR3Compiler
|
||||
from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport, cpu_profile, lo32, PROFILE, suppress_finalizing
|
||||
from tinygrad.helpers import flatten, QCOM_IR3, QCOM_CC
|
||||
from tinygrad.runtime.support.system import System
|
||||
@@ -226,10 +227,10 @@ class IR3ArgsState(HCQArgsState):
|
||||
class QCOMProgram(HCQProgram):
|
||||
def __init__(self, dev: QCOMDevice, name: str, lib: bytes):
|
||||
self.dev: QCOMDevice = dev
|
||||
self.name, self.lib, self.NIR = name, lib, isinstance(dev.renderer, IR3Renderer)
|
||||
self.name, self.lib, self.NIR = name, lib, isinstance(dev.compiler, IR3Compiler)
|
||||
|
||||
if self.NIR:
|
||||
from tinygrad.runtime.support.compiler_mesa import IR3Compiler
|
||||
from tinygrad.runtime.autogen import mesa
|
||||
v, cs, self.imm_vals, self.image = IR3Compiler.unpack_lib(lib)
|
||||
self.prg_offset, self.brnchstck, self.image_size, self.pvtmem, self.shmem = 0, v.branchstack, v.info.size, v.pvtmem_size, v.shared_size
|
||||
self.wgsz = alloc.offset_vec4 * 4 + 8 if (alloc:=cs.allocs.consts[mesa.IR3_CONST_ALLOC_DRIVER_PARAMS]).size_vec4 else 0xfc
|
||||
@@ -401,7 +402,7 @@ class QCOMDevice(HCQCompiled):
|
||||
System.write_sysfs("/sys/class/kgsl/kgsl-3d0/idle_timer", value="4000000000", msg="Failed to disable suspend mode", expected="4294967276")
|
||||
|
||||
compilers = CompilerSet(ctrl_var=QCOM_CC, cset=[CompilerPair(QCOMRenderer, functools.partial(QCOMCompiler, device)),
|
||||
CompilerPair(functools.partial(IR3Renderer, info.chip_id), None, QCOM_IR3)])
|
||||
CompilerPair(functools.partial(IR3Renderer, self), functools.partial(IR3Compiler, info.chip_id), QCOM_IR3)])
|
||||
super().__init__(device, QCOMAllocator(self), compilers, functools.partial(QCOMProgram, self), QCOMSignal,
|
||||
functools.partial(QCOMComputeQueue, self), None)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ def expect(x, err, ret=None):
|
||||
class LLVMCompiler(Compiler):
|
||||
jit = True
|
||||
target_arch = {'arm64': 'AArch64', 'aarch64': 'AArch64', 'x86_64': 'X86', 'AMD64': 'X86', 'riscv64': 'riscv64'}[platform.machine()]
|
||||
def __init__(self, processor:str, feats:str, cache_key=None):
|
||||
def __init__(self, processor:str, feats:str):
|
||||
for component in ['Target', 'TargetInfo', 'TargetMC', 'AsmParser', 'AsmPrinter']: getattr(llvm, f'LLVMInitialize{self.target_arch}{component}')()
|
||||
|
||||
triple = {'AArch64': b'aarch64-none-unknown-elf', 'X86': b'x86_64-none-unknown-elf', 'AMDGPU': b'amdgcn-amd-amdhsa'}[self.target_arch]
|
||||
@@ -59,7 +59,7 @@ class LLVMCompiler(Compiler):
|
||||
self.diag_msgs.append(msg)
|
||||
self.handle_diag = handle_diag
|
||||
llvm.LLVMContextSetDiagnosticHandler(self.context, handle_diag, None)
|
||||
super().__init__(cache_key or f"compile_llvm_{processor}_{feats}{'_jit' if self.jit else ''}{'_opt' if opt else ''}")
|
||||
super().__init__(f"compile_llvm_{processor}_{feats}{'_jit' if self.jit else ''}{'_opt' if opt else ''}")
|
||||
|
||||
def __del__(self):
|
||||
llvm.LLVMDisposePassBuilderOptions(self.pbo)
|
||||
@@ -83,7 +83,7 @@ class LLVMCompiler(Compiler):
|
||||
def disassemble(self, lib:bytes): capstone_flatdump(lib)
|
||||
|
||||
class CPULLVMCompiler(LLVMCompiler):
|
||||
def __init__(self, cache_key=None):
|
||||
def __init__(self):
|
||||
# +reserve-x18 here does the same thing as -ffixed-x18 in ops_cpu.py, see comments there for why it's needed on arm osx
|
||||
cpu, feats = ctypes.string_at(llvm.LLVMGetHostCPUName()), (b'+reserve-x18,' if OSX else b'') + ctypes.string_at(llvm.LLVMGetHostCPUFeatures())
|
||||
super().__init__(cpu.decode(), feats.decode(), cache_key)
|
||||
super().__init__(cpu.decode(), feats.decode())
|
||||
|
||||
@@ -4,8 +4,6 @@ from tinygrad.helpers import cpu_objdump, system, data64
|
||||
from tinygrad.runtime.autogen import mesa, llvm
|
||||
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, expect, cerr
|
||||
|
||||
# NB: compilers assume mesa's glsl type cache is managed externally with mesa.glsl_type_singleton_init_or_ref() and mesa.glsl_type_singleton_decref()
|
||||
|
||||
def rzalloc(typ, ctx=None, **kwargs):
|
||||
s = ctypes.cast(mesa.rzalloc_size(ctypes.cast(ctx, ctypes.c_void_p), ctypes.sizeof(typ)), ctypes.POINTER(typ))
|
||||
for k,v in kwargs.items(): setattr(s.contents, k, v)
|
||||
@@ -16,8 +14,20 @@ def deserialize(enc_src, opts):
|
||||
mesa.blob_reader_init(blobreader, src:=base64.b64decode(enc_src), len(src))
|
||||
return mesa.nir_deserialize(None, ctypes.cast(opts, ctypes.POINTER(mesa.nir_shader_compiler_options)), blobreader)
|
||||
|
||||
class LVPCompiler(CPULLVMCompiler):
|
||||
def __init__(self, cache_key="lvp"): CPULLVMCompiler.__init__(self, cache_key=f"compile_{cache_key}")
|
||||
class NIRCompiler(Compiler):
|
||||
def __init__(self, cache_key):
|
||||
mesa.glsl_type_singleton_init_or_ref()
|
||||
super().__init__(cache_key)
|
||||
def __del__(self): mesa.glsl_type_singleton_decref()
|
||||
|
||||
class LVPCompiler(CPULLVMCompiler, NIRCompiler):
|
||||
def __init__(self, cache_key="lvp"):
|
||||
CPULLVMCompiler.__init__(self)
|
||||
NIRCompiler.__init__(self, f"compile_{cache_key}")
|
||||
|
||||
def __del__(self):
|
||||
NIRCompiler.__del__(self)
|
||||
CPULLVMCompiler.__del__(self)
|
||||
|
||||
def compile(self, src) -> bytes:
|
||||
shader, ctx = deserialize(src, mesa.lvp_nir_options), llvm.LLVMGetGlobalContext()
|
||||
@@ -50,14 +60,16 @@ class LVPCompiler(CPULLVMCompiler):
|
||||
|
||||
def disassemble(self, lib: bytes): cpu_objdump(lib)
|
||||
|
||||
class NAKCompiler(Compiler):
|
||||
class NAKCompiler(NIRCompiler):
|
||||
def __init__(self, arch, warps_per_sm, cache_key="nak"):
|
||||
self.arch, self.warps_per_sm = arch, warps_per_sm
|
||||
self.cc = mesa.nak_compiler_create(mesa.struct_nv_device_info(sm=int(arch[3:]), max_warps_per_mp=warps_per_sm))
|
||||
self.nir_options = bytes(mesa.nak_nir_options(self.cc).contents)
|
||||
super().__init__(f"compile_{cache_key}_{arch}")
|
||||
|
||||
def __del__(self): mesa.nak_compiler_destroy(self.cc)
|
||||
def __del__(self):
|
||||
mesa.nak_compiler_destroy(self.cc)
|
||||
super().__del__()
|
||||
|
||||
def __reduce__(self): return NAKCompiler, (self.arch, self.warps_per_sm)
|
||||
|
||||
@@ -88,7 +100,7 @@ def disas_adreno(lib:bytes, gpu_id=630):
|
||||
tf.seek(0)
|
||||
print(tf.read())
|
||||
|
||||
class IR3Compiler(Compiler):
|
||||
class IR3Compiler(NIRCompiler):
|
||||
def __init__(self, chip_id, cache_key="ir3"):
|
||||
assert sys.version_info >= (3,14), "IR3 requires python 3.14's bitfield fixes"
|
||||
self.dev_id = mesa.struct_fd_dev_id(((chip_id >> 24) & 0xFF) * 100 + ((chip_id >> 16) & 0xFF) * 10 + ((chip_id >> 8) & 0xFF), chip_id)
|
||||
@@ -98,7 +110,9 @@ class IR3Compiler(Compiler):
|
||||
self.nir_options = bytes(mesa.ir3_get_compiler_options(self.cc).contents)
|
||||
super().__init__(f"compile_{cache_key}")
|
||||
|
||||
def __del__(self): mesa.ir3_compiler_destroy(self.cc)
|
||||
def __del__(self):
|
||||
mesa.ir3_compiler_destroy(self.cc)
|
||||
super().__del__()
|
||||
|
||||
def __reduce__(self): return IR3Compiler, (self.dev_id.chip_id,)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user