working to improve ptx (#3647)
* working to improve ptx
* fix compile fail
@@ -35,7 +35,7 @@ def _get_bytes(arg, get_str, get_sz, check) -> bytes:
   return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)
 
 class PTXCompiler(Compiler):
-  linearizer_opts = LinearizerOptions("CUDA", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], supports_float4=False)
+  linearizer_opts = LinearizerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], supports_float4=False)
   def __init__(self, arch:str):
     self.arch = arch
     PTXCompiler.linearizer_opts = PTXCompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80)
@@ -58,21 +58,27 @@ class CUDACompiler(Compiler):
     if status != 0: raise RuntimeError(f"compile failed: {_get_bytes(prog, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check).decode()}")
     return _get_bytes(prog, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, check)
 
+def cuda_disassemble(lib, arch):
+  try:
+    fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
+    with open(fn + ".ptx", "wb") as f: f.write(lib)
+    subprocess.run(["ptxas", f"-arch={arch}", "-o", fn, fn+".ptx"], check=True)
+    print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
+  except Exception as e: print("failed to generate SASS", str(e))
+
 class CUDAProgram:
   def __init__(self, device:CUDADevice, name:str, lib:bytes):
     self.device, self.name, self.lib = device, name, lib
     if DEBUG >= 5: print("\n".join([f"{i+1:>3} {line}" for i, line in enumerate(pretty_ptx(lib.decode('utf-8')).split("\n"))]))
-    if DEBUG >= 6:
-      try:
-        fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
-        with open(fn + ".ptx", "wb") as f: f.write(lib)
-        subprocess.run(["ptxas", f"-arch={device.arch}", "-o", fn, fn+".ptx"], check=True)
-        print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
-      except Exception as e: print("failed to generate SASS", str(e))
+    if DEBUG >= 6: cuda_disassemble(lib, device.arch)
 
     if not CUDACPU:
       check(cuda.cuCtxSetCurrent(self.device.context))
-      self.module = init_c_var(cuda.CUmodule(), lambda x: check(cuda.cuModuleLoadData(ctypes.byref(x), lib)))
+      self.module = cuda.CUmodule()
+      status = cuda.cuModuleLoadData(ctypes.byref(self.module), lib)
+      if status != 0:
+        cuda_disassemble(lib, device.arch)
+        raise RuntimeError("module load failed")
       check(cuda.cuModuleGetFunction(ctypes.byref(prg := cuda.CUfunction()), self.module, name.encode("utf-8")))
     self.prg = prg if not CUDACPU else lib
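For context, the cuda_disassemble helper this commit factors out can be exercised standalone; a minimal sketch, assuming a CUDA toolkit with ptxas and nvdisasm on PATH (the no-op PTX kernel and the sm_80 arch string are illustrative, not from the commit):

import hashlib, subprocess, tempfile
from pathlib import Path

def cuda_disassemble(lib, arch):
  # same shape as the helper in this commit: write the PTX to a temp file,
  # assemble it with ptxas for the given arch, then print the SASS via nvdisasm
  try:
    fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
    with open(fn + ".ptx", "wb") as f: f.write(lib)
    subprocess.run(["ptxas", f"-arch={arch}", "-o", fn, fn+".ptx"], check=True)
    print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
  except Exception as e: print("failed to generate SASS", str(e))

# illustrative no-op kernel; any valid PTX bytes work here
ptx = b""".version 7.0
.target sm_80
.address_size 64
.visible .entry noop() { ret; }
"""
cuda_disassemble(ptx, "sm_80")

Extracting the helper lets the module-load failure path above dump SASS before raising, instead of only disassembling at DEBUG >= 6.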