save more lines

This commit is contained in:
Christopher Milan
2025-10-02 10:18:08 -07:00
parent 75a84b2d04
commit dd3a720c5a
2 changed files with 12 additions and 23 deletions

View File

@@ -179,14 +179,6 @@ class NIRRenderer(Renderer):
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: ensure(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
])
def __init__(self, device):
self.device = device
mesa.glsl_type_singleton_init_or_ref()
def __del__(self):
try: mesa.glsl_type_singleton_decref()
except AttributeError: pass
@property
def nir_options(self): raise NotImplementedError("needs nir_options")
def param(self, dtype:DType, sz:int) -> mesa.nir_def: raise NotImplementedError("needs param")
@@ -194,6 +186,7 @@ class NIRRenderer(Renderer):
self.b = mesa.nir_builder_init_simple_shader(mesa.MESA_SHADER_COMPUTE, mesa.nir_shader_compiler_options.from_buffer_copy(self.nir_options), None)
def render(self, uops:list[UOp]):
mesa.glsl_type_singleton_init_or_ref()
self.prerender(uops)
for u in [u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]: self.b.shader.contents.info.workgroup_size[int(u.arg[-1])] = u.src[0].arg
self.r, self.param_idx, ranges = {}, 0, []
@@ -231,14 +224,16 @@ class NIRRenderer(Renderer):
mesa.ralloc_free(self.b.shader)
ctypes.CDLL(ctypes.util.find_library('c')).free(blob.data)
del self.b, self.r
mesa.glsl_type_singleton_decref()
return ret
class NAKRenderer(NIRRenderer):
def __init__(self, dev=None, nir_options=None, device="NV"):
device = "NV"
def __init__(self, dev=None, nir_options=None):
if dev: self.dev = dev
else: self.__dict__['nir_options'] = nir_options
super().__init__(device)
@classmethod
def with_opts(cls, opts): return cls(nir_options=opts)
@@ -262,13 +257,12 @@ class NAKRenderer(NIRRenderer):
return d(intrin)
class LVPRenderer(NIRRenderer):
device = "CPU"
has_local = False
has_shared = False
global_max = (1, 0, 0)
nir_options = mesa.lvp_nir_options
def __init__(self, device="CPU"): super().__init__(device)
def param(self, dtype:DType, sz:int) -> mesa.nir_def:
intrin = mesa.nir_intrinsic_instr_create(self.b.shader, mesa.nir_intrinsic_load_ubo)
intrin.contents.num_components = 1

View File

@@ -8,8 +8,7 @@ try: import tinygrad.runtime.autogen.llvm as llvm
except (ImportError, FileNotFoundError): llvm = None #type:ignore[assignment]
def deserialize(enc_src, opts):
blobreader = mesa.struct_blob_reader()
mesa.blob_reader_init(blobreader, src:=base64.b64decode(enc_src), len(src))
mesa.blob_reader_init(blobreader:=mesa.struct_blob_reader(), src:=base64.b64decode(enc_src), len(src))
return mesa.nir_deserialize(None, ctypes.cast(opts, ctypes.POINTER(mesa.nir_shader_compiler_options)), blobreader)
class LVPCompiler(Compiler):
@@ -48,16 +47,14 @@ class LVPCompiler(Compiler):
def disassemble(self, lib:bytes): cpu_objdump(lib)
class NAKCompiler(Compiler):
def __init__(self, arch, warps_per_sm, cache_key="nak"):
self.arch, self.warps_per_sm = arch, warps_per_sm
self.cc = mesa.nak_compiler_create(mesa.struct_nv_device_info(sm=int(arch[3:]), max_warps_per_mp=warps_per_sm))
def __init__(self, arch, warps, cache_key="nak"):
self.arch, self.warps, self.cc = arch, warps, mesa.nak_compiler_create(mesa.struct_nv_device_info(sm=int(arch[3:]), max_warps_per_mp=warps))
self.nir_options = bytes(mesa.nak_nir_options(self.cc).contents)
mesa.glsl_type_singleton_init_or_ref()
super().__init__(f"compile_{cache_key}_{arch}")
def __del__(self):
mesa.nak_compiler_destroy(self.cc)
mesa.glsl_type_singleton_decref()
def __del__(self): mesa.nak_compiler_destroy(self.cc)
def __reduce__(self): return NAKCompiler, (self.arch, self.warps)
def compile(self, src) -> bytes:
shader = deserialize(src, self.nir_options)
@@ -74,8 +71,6 @@ class NAKCompiler(Compiler):
print(subprocess.check_output(['nvdisasm', "-b", f"SM{self.arch[3:]}", fn]).decode('utf-8'))
except Exception as e: print("Failed to generate SASS", str(e), "Make sure your PATH contains nvdisasm binary of compatible version.")
def __reduce__(self): return NAKCompiler, (self.arch, self.warps_per_sm)
def parse_nak_shader(shader:bytes) -> Tuple[memoryview, int, int, int]:
info = mesa.struct_nak_shader_info.from_buffer(shader)
return (memoryview(shader[ctypes.sizeof(info):]), info.num_gprs, round_up(info.cs.smem_size, 0x80), round_up(info.slm_size, 0x10))