working to improve ptx (#3647)

* working to improve ptx

* fix compile fail
Author: George Hotz
Date: 2024-03-07 12:39:31 -08:00 (committed by GitHub)
Parent: 1853ec9a02
Commit: 6e50582e62
5 changed files with 26 additions and 21 deletions


@@ -70,6 +70,7 @@ class LocalBuffer(NamedTuple):
 class LinearizerOptions(NamedTuple):
   device: str = ""
+  suffix: str = ""
   # TODO: make this generic with a list of supported types
   supports_float4: bool = True
   has_local: bool = True

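The new suffix field is what lets two backends that report the same device string (C-style CUDA and direct PTX) stay distinguishable downstream. A minimal standalone sketch of that idea, using illustrative names rather than tinygrad's actual classes:

from typing import NamedTuple

class Opts(NamedTuple):
  device: str = ""
  suffix: str = ""          # e.g. "PTX" for the assembly backend on the CUDA device
  supports_float4: bool = True

cuda_opts = Opts(device="CUDA")                # C-style CUDA renderer
ptx_opts = Opts(device="CUDA", suffix="PTX")   # direct-PTX renderer
# without suffix, the two option sets would be indistinguishable in any key built from them
assert cuda_opts != ptx_opts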

@@ -96,7 +96,7 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
 beam_pool = None
 def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
   global beam_pool
-  key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device}
+  key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
   if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
     ret = lin.copy()
     for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
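Adding "suffix" to the key keeps PTX and C-style CUDA beam results from aliasing each other, since both backends report device "CUDA". A toy model of the collision this avoids (a plain dict stands in for diskcache_get/diskcache_put):

beam_cache = {}

def beam_key(ast_key, amt, device, suffix):
  # the real key is a dict that gets serialized; a tuple behaves the same for this sketch
  return (ast_key, amt, device, suffix)

beam_cache[beam_key("ast0", 8, "CUDA", "")] = ["opts picked for the C-style kernel"]
beam_cache[beam_key("ast0", 8, "CUDA", "PTX")] = ["opts picked for the PTX kernel"]
# with suffix in the key, the two backends no longer read each other's beam results
assert beam_cache[("ast0", 8, "CUDA", "")] != beam_cache[("ast0", 8, "CUDA", "PTX")]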
@@ -105,7 +105,7 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
   beam: List[Tuple[Linearizer, float]] = []
   seen_libs = set()
-  default_parallel, min_progress_micros = 1 if lin.opts.device in {"CUDA", "HIP", "HSA"} else 0, getenv("BEAM_MIN_PROGRESS",0)
+  default_parallel, min_progress_micros = 1 if lin.opts.device in {"CUDA", "HIP", "HSA"} else 0, getenv("BEAM_MIN_PROGRESS",0.01)
   if beam_pool is None and getenv("PARALLEL", default_parallel): beam_pool = multiprocessing.Pool(multiprocessing.cpu_count(), _init_worker)
   try:
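Changing the default from 0 to 0.01 does more than tweak the threshold: getenv (assumed here to follow tinygrad's helpers.getenv, which casts the environment value to the type of the default) now parses BEAM_MIN_PROGRESS as a float instead of an int. A small sketch:

import os

def getenv_sketch(key, default=0):
  # assumed behavior of tinygrad's helpers.getenv: coerce to the type of the default
  return type(default)(os.getenv(key, default))

os.environ["BEAM_MIN_PROGRESS"] = "0.5"
assert getenv_sketch("BEAM_MIN_PROGRESS", 0.01) == 0.5   # float default -> "0.5" parses fine
# with the old int default, the same setting would raise: int("0.5") is a ValueError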
@@ -155,7 +155,7 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
   return ret[1]
 def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
-  key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device} # noqa: E501
+  key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix} # noqa: E501
   if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
   dev = Device[lin.opts.device]

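The timing cache follows the same pattern: each key maps to a list of measured runtimes and a cache hit returns the fastest one (the min(val) above). A toy version of that contract, with stand-in names:

timing_cache = {}

def time_kernel_cached(key, measure, cnt=3):
  # on a hit, return the best previously observed time, like `return min(val)` above
  if key in timing_cache: return min(timing_cache[key])
  times = [measure() for _ in range(cnt)]
  timing_cache[key] = times
  return min(times)

print(time_kernel_cached(("ast0", "CUDA", "PTX"), measure=lambda: 1.5e-3))  # 0.0015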

@@ -129,7 +129,7 @@ _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches"
 CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
 CACHELEVEL = getenv("CACHELEVEL", 2)
-VERSION = 12
+VERSION = 13
 _db_connection = None
 def db_connection():
   global _db_connection

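Bumping VERSION is the blunt instrument that invalidates every existing disk-cache entry at once, which is needed here because old beam and timing entries were written without the "suffix" key. The sketch below assumes, for illustration only, that the version number is folded into the cache table name; the exact schema is tinygrad's and not reproduced here.

import sqlite3

def table(base, version):
  return f"{base}_{version}"

conn = sqlite3.connect(":memory:")
for v in (12, 13):
  conn.execute(f"CREATE TABLE {table('beam_search', v)} (key TEXT PRIMARY KEY, val BLOB)")
conn.execute(f"INSERT INTO {table('beam_search', 12)} VALUES (?, ?)", ("k", b"stale entry"))
# after the bump, lookups hit the _13 table and find nothing, so everything re-tunes
assert conn.execute(f"SELECT val FROM {table('beam_search', 13)} WHERE key=?", ("k",)).fetchone() is None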

@@ -37,7 +37,6 @@ class AssemblyLanguage(NamedTuple):
   def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError()
 def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
-  local_size: List[int] = []
   kernel:List[str] = []
   bufs = []
@@ -131,12 +130,8 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
       kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
     elif uop == UOps.DEFINE_ACC: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
     elif uop == UOps.SPECIAL:
-      if args[1][0] == "i": kk(f"mov.u32 %{args[1]}, {lang.gid[args[0]]};", f"mov.u32 {(gdim:=ssa(None,'tmp','u32'))}, {lang.gdim[args[0]]};",
-                               f"mov.u32 {(lid:=ssa(None,'tmp','u32'))}, {lang.lid[args[0]]};",
-                               f"mad.lo.u32 {(tmp:=ssa(None, 'tmp', 'u32'))}, %{args[1]}, {gdim}, {lid};")
-      else: kk(f"mov.u32 {(tmp:=ssa(None, 'tmp', 'u32'))}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
-      kk(*lang.render_cast(f"%{args[1]}", tmp, dtypes.uint, dtypes.int))
-      if args[1][0] == "l": local_size.append(args[2])
+      assert args[1][0] != "i", "idx not supported"
+      kk(f"mov.u32 %{args[1]}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
       r[u] = "%" + args[1]
       kernel = [f".reg .u32 %{args[1]};"] + kernel
     elif uop == UOps.CONST: r[u] = const(args, dtype, mov=True)
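The rewritten SPECIAL branch drops the flattened "idx" form (and the int cast through a temporary register) and moves the grid or block index straight into a named u32 register. Assuming the usual mapping of gid/lid onto PTX special registers (%ctaid.* and %tid.*), this is roughly what it emits:

# illustrative values for lang.gid / lang.lid; the SPECIAL arg is (axis, name, size)
gid = [f"%ctaid.{chr(120+i)}" for i in range(3)]   # %ctaid.x, %ctaid.y, %ctaid.z
lid = [f"%tid.{chr(120+i)}" for i in range(3)]     # %tid.x,   %tid.y,   %tid.z

args = (0, "gidx0", 16)
print(f"mov.u32 %{args[1]}, {(gid if args[1][0] == 'g' else lid)[args[0]]};")  # mov.u32 %gidx0, %ctaid.x;
args = (1, "lidx1", 8)
print(f"mov.u32 %{args[1]}, {(gid if args[1][0] == 'g' else lid)[args[0]]};")  # mov.u32 %lidx1, %tid.y;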
@@ -157,7 +152,10 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
     elif uop == UOps.CAST:
       assert vin[0].dtype is not None
       cast(r[vin[0]], dtype, vin[0].dtype, bitcast=isinstance(args, tuple) and args[1], u=u)
-    elif uop == UOps.DEFINE_LOCAL: kk(*lang.render_local(ssa(u, 'local', lang.types[dtypes.ulong]), args[0], args[1], dtype))
+    elif uop == UOps.DEFINE_LOCAL:
+      # TODO: we should sum these, and fetch 0xC000 from somewhere
+      assert args[1]*dtype.itemsize <= 0xC000, "too large local"
+      kk(*lang.render_local(ssa(u, 'local', lang.types[dtypes.ulong]), args[0], args[1], dtype))
     elif uop is UOps.DEFINE_VAR:
       bufs.append((args.expr, dtype))
       r[u] = f"%{args.expr}"

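The magic number in the new assert is 0xC000 = 49152 bytes = 48 KiB, the classic per-block static shared-memory limit on NVIDIA GPUs, which a PTX .shared declaration has to fit into. As the TODO notes, the check is per DEFINE_LOCAL rather than a sum over all of them. A sketch of the check itself:

def fits_in_shared(num_elements, itemsize, limit=0xC000):
  # mirrors `args[1]*dtype.itemsize <= 0xC000` for one local buffer
  return num_elements * itemsize <= limit

assert 0xC000 == 48 * 1024
assert fits_in_shared(8192, 4)         # 32 KiB of float32 scratch: ok
assert not fits_in_shared(16384, 4)    # 64 KiB: rejected as "too large local"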

@@ -35,7 +35,7 @@ def _get_bytes(arg, get_str, get_sz, check) -> bytes:
   return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)
 class PTXCompiler(Compiler):
-  linearizer_opts = LinearizerOptions("CUDA", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], supports_float4=False)
+  linearizer_opts = LinearizerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], supports_float4=False)
   def __init__(self, arch:str):
     self.arch = arch
     PTXCompiler.linearizer_opts = PTXCompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80)
@@ -58,21 +58,27 @@ class CUDACompiler(Compiler):
     if status != 0: raise RuntimeError(f"compile failed: {_get_bytes(prog, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check).decode()}")
     return _get_bytes(prog, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, check)
+def cuda_disassemble(lib, arch):
+  try:
+    fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
+    with open(fn + ".ptx", "wb") as f: f.write(lib)
+    subprocess.run(["ptxas", f"-arch={arch}", "-o", fn, fn+".ptx"], check=True)
+    print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
+  except Exception as e: print("failed to generate SASS", str(e))
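The disassembly path that used to live inline behind DEBUG >= 6 is now a module-level helper, so the failing-load path below can reuse it. It assembles the textual PTX with ptxas and prints the SASS via nvdisasm; both tools have to be on PATH, otherwise the except branch only reports the failure. A quick availability check:

import shutil
for tool in ("ptxas", "nvdisasm"):
  print(tool, "->", shutil.which(tool) or "not on PATH (cuda_disassemble will report a failure)")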
 class CUDAProgram:
   def __init__(self, device:CUDADevice, name:str, lib:bytes):
     self.device, self.name, self.lib = device, name, lib
     if DEBUG >= 5: print("\n".join([f"{i+1:>3} {line}" for i, line in enumerate(pretty_ptx(lib.decode('utf-8')).split("\n"))]))
-    if DEBUG >= 6:
-      try:
-        fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
-        with open(fn + ".ptx", "wb") as f: f.write(lib)
-        subprocess.run(["ptxas", f"-arch={device.arch}", "-o", fn, fn+".ptx"], check=True)
-        print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
-      except Exception as e: print("failed to generate SASS", str(e))
+    if DEBUG >= 6: cuda_disassemble(lib, device.arch)
     if not CUDACPU:
       check(cuda.cuCtxSetCurrent(self.device.context))
-      self.module = init_c_var(cuda.CUmodule(), lambda x: check(cuda.cuModuleLoadData(ctypes.byref(x), lib)))
+      self.module = cuda.CUmodule()
+      status = cuda.cuModuleLoadData(ctypes.byref(self.module), lib)
+      if status != 0:
+        cuda_disassemble(lib, device.arch)
+        raise RuntimeError("module load failed")
       check(cuda.cuModuleGetFunction(ctypes.byref(prg := cuda.CUfunction()), self.module, name.encode("utf-8")))
     self.prg = prg if not CUDACPU else lib
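The load path now tests the cuModuleLoadData status explicitly instead of going through check(), so a rejected module gets disassembled for inspection before the exception is raised. A stand-in model of that control flow (not the CUDA driver API):

def load_or_dump(load_fn, lib, disassemble):
  status = load_fn(lib)
  if status != 0:
    disassemble(lib)                 # corresponds to cuda_disassemble(lib, device.arch)
    raise RuntimeError("module load failed")

load_or_dump(lambda lib: 0, b"good ptx", disassemble=lambda lib: None)    # ok path
try:
  load_or_dump(lambda lib: 218, b"bad ptx", disassemble=lambda lib: print("dumping SASS for inspection"))
except RuntimeError as e:
  print(e)                           # module load failed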