mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
save more lines
This commit is contained in:
@@ -179,14 +179,6 @@ class NIRRenderer(Renderer):
|
||||
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: ensure(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
|
||||
])
|
||||
|
||||
def __init__(self, device):
|
||||
self.device = device
|
||||
mesa.glsl_type_singleton_init_or_ref()
|
||||
|
||||
def __del__(self):
|
||||
try: mesa.glsl_type_singleton_decref()
|
||||
except AttributeError: pass
|
||||
|
||||
@property
|
||||
def nir_options(self): raise NotImplementedError("needs nir_options")
|
||||
def param(self, dtype:DType, sz:int) -> mesa.nir_def: raise NotImplementedError("needs param")
|
||||
@@ -194,6 +186,7 @@ class NIRRenderer(Renderer):
|
||||
self.b = mesa.nir_builder_init_simple_shader(mesa.MESA_SHADER_COMPUTE, mesa.nir_shader_compiler_options.from_buffer_copy(self.nir_options), None)
|
||||
|
||||
def render(self, uops:list[UOp]):
|
||||
mesa.glsl_type_singleton_init_or_ref()
|
||||
self.prerender(uops)
|
||||
for u in [u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]: self.b.shader.contents.info.workgroup_size[int(u.arg[-1])] = u.src[0].arg
|
||||
self.r, self.param_idx, ranges = {}, 0, []
|
||||
@@ -231,14 +224,16 @@ class NIRRenderer(Renderer):
|
||||
mesa.ralloc_free(self.b.shader)
|
||||
ctypes.CDLL(ctypes.util.find_library('c')).free(blob.data)
|
||||
del self.b, self.r
|
||||
mesa.glsl_type_singleton_decref()
|
||||
|
||||
return ret
|
||||
|
||||
class NAKRenderer(NIRRenderer):
|
||||
def __init__(self, dev=None, nir_options=None, device="NV"):
|
||||
device = "NV"
|
||||
|
||||
def __init__(self, dev=None, nir_options=None):
|
||||
if dev: self.dev = dev
|
||||
else: self.__dict__['nir_options'] = nir_options
|
||||
super().__init__(device)
|
||||
|
||||
@classmethod
|
||||
def with_opts(cls, opts): return cls(nir_options=opts)
|
||||
@@ -262,13 +257,12 @@ class NAKRenderer(NIRRenderer):
|
||||
return d(intrin)
|
||||
|
||||
class LVPRenderer(NIRRenderer):
|
||||
device = "CPU"
|
||||
has_local = False
|
||||
has_shared = False
|
||||
global_max = (1, 0, 0)
|
||||
nir_options = mesa.lvp_nir_options
|
||||
|
||||
def __init__(self, device="CPU"): super().__init__(device)
|
||||
|
||||
def param(self, dtype:DType, sz:int) -> mesa.nir_def:
|
||||
intrin = mesa.nir_intrinsic_instr_create(self.b.shader, mesa.nir_intrinsic_load_ubo)
|
||||
intrin.contents.num_components = 1
|
||||
|
||||
@@ -8,8 +8,7 @@ try: import tinygrad.runtime.autogen.llvm as llvm
|
||||
except (ImportError, FileNotFoundError): llvm = None #type:ignore[assignment]
|
||||
|
||||
def deserialize(enc_src, opts):
|
||||
blobreader = mesa.struct_blob_reader()
|
||||
mesa.blob_reader_init(blobreader, src:=base64.b64decode(enc_src), len(src))
|
||||
mesa.blob_reader_init(blobreader:=mesa.struct_blob_reader(), src:=base64.b64decode(enc_src), len(src))
|
||||
return mesa.nir_deserialize(None, ctypes.cast(opts, ctypes.POINTER(mesa.nir_shader_compiler_options)), blobreader)
|
||||
|
||||
class LVPCompiler(Compiler):
|
||||
@@ -48,16 +47,14 @@ class LVPCompiler(Compiler):
|
||||
def disassemble(self, lib:bytes): cpu_objdump(lib)
|
||||
|
||||
class NAKCompiler(Compiler):
|
||||
def __init__(self, arch, warps_per_sm, cache_key="nak"):
|
||||
self.arch, self.warps_per_sm = arch, warps_per_sm
|
||||
self.cc = mesa.nak_compiler_create(mesa.struct_nv_device_info(sm=int(arch[3:]), max_warps_per_mp=warps_per_sm))
|
||||
def __init__(self, arch, warps, cache_key="nak"):
|
||||
self.arch, self.warps, self.cc = arch, warps, mesa.nak_compiler_create(mesa.struct_nv_device_info(sm=int(arch[3:]), max_warps_per_mp=warps))
|
||||
self.nir_options = bytes(mesa.nak_nir_options(self.cc).contents)
|
||||
mesa.glsl_type_singleton_init_or_ref()
|
||||
super().__init__(f"compile_{cache_key}_{arch}")
|
||||
|
||||
def __del__(self):
|
||||
mesa.nak_compiler_destroy(self.cc)
|
||||
mesa.glsl_type_singleton_decref()
|
||||
def __del__(self): mesa.nak_compiler_destroy(self.cc)
|
||||
|
||||
def __reduce__(self): return NAKCompiler, (self.arch, self.warps)
|
||||
|
||||
def compile(self, src) -> bytes:
|
||||
shader = deserialize(src, self.nir_options)
|
||||
@@ -74,8 +71,6 @@ class NAKCompiler(Compiler):
|
||||
print(subprocess.check_output(['nvdisasm', "-b", f"SM{self.arch[3:]}", fn]).decode('utf-8'))
|
||||
except Exception as e: print("Failed to generate SASS", str(e), "Make sure your PATH contains nvdisasm binary of compatible version.")
|
||||
|
||||
def __reduce__(self): return NAKCompiler, (self.arch, self.warps_per_sm)
|
||||
|
||||
def parse_nak_shader(shader:bytes) -> Tuple[memoryview, int, int, int]:
|
||||
info = mesa.struct_nak_shader_info.from_buffer(shader)
|
||||
return (memoryview(shader[ctypes.sizeof(info):]), info.num_gprs, round_up(info.cs.smem_size, 0x80), round_up(info.slm_size, 0x10))
|
||||
|
||||
Reference in New Issue
Block a user