diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 285f47aff9..ad8c3ba367 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -70,6 +70,7 @@ class LocalBuffer(NamedTuple):
 
 class LinearizerOptions(NamedTuple):
   device: str = ""
+  suffix: str = ""
   # TODO: make this generic with a list of supported types
   supports_float4: bool = True
   has_local: bool = True
diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py
index 7e9679fc88..95a738e634 100644
--- a/tinygrad/features/search.py
+++ b/tinygrad/features/search.py
@@ -96,7 +96,7 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
 beam_pool = None
 def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
   global beam_pool
-  key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device}
+  key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
   if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
     ret = lin.copy()
     for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
@@ -105,7 +105,7 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
   beam: List[Tuple[Linearizer, float]] = []
   seen_libs = set()
 
-  default_parallel, min_progress_micros = 1 if lin.opts.device in {"CUDA", "HIP", "HSA"} else 0, getenv("BEAM_MIN_PROGRESS",0)
+  default_parallel, min_progress_micros = 1 if lin.opts.device in {"CUDA", "HIP", "HSA"} else 0, getenv("BEAM_MIN_PROGRESS",0.01)
   if beam_pool is None and getenv("PARALLEL", default_parallel): beam_pool = multiprocessing.Pool(multiprocessing.cpu_count(), _init_worker)
 
   try:
@@ -155,7 +155,7 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
   return ret[1]
 
 def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
-  key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device} # noqa: E501
+  key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix} # noqa: E501
   if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
 
   dev = Device[lin.opts.device]
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 6cff505237..47bdff2d50 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -129,7 +129,7 @@ _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches"
 CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
 CACHELEVEL = getenv("CACHELEVEL", 2)
 
-VERSION = 12
+VERSION = 13
 _db_connection = None
 def db_connection():
   global _db_connection
diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py
index fdd008e433..64240cf0d1 100644
--- a/tinygrad/renderer/assembly.py
+++ b/tinygrad/renderer/assembly.py
@@ -37,7 +37,6 @@ class AssemblyLanguage(NamedTuple):
   def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError()
 
 def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
-  local_size: List[int] = []
   kernel:List[str] = []
   bufs = []
 
@@ -131,12 +130,8 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
         kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
     elif uop == UOps.DEFINE_ACC: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
     elif uop == UOps.SPECIAL:
-      if args[1][0] == "i": kk(f"mov.u32 %{args[1]}, {lang.gid[args[0]]};", f"mov.u32 {(gdim:=ssa(None,'tmp','u32'))}, {lang.gdim[args[0]]};",
-                               f"mov.u32 {(lid:=ssa(None,'tmp','u32'))}, {lang.lid[args[0]]};",
-                               f"mad.lo.u32 {(tmp:=ssa(None, 'tmp', 'u32'))}, %{args[1]}, {gdim}, {lid};")
-      else: kk(f"mov.u32 {(tmp:=ssa(None, 'tmp', 'u32'))}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
-      kk(*lang.render_cast(f"%{args[1]}", tmp, dtypes.uint, dtypes.int))
-      if args[1][0] == "l": local_size.append(args[2])
+      assert args[1][0] != "i", "idx not supported"
+      kk(f"mov.u32 %{args[1]}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
       r[u] = "%" + args[1]
       kernel = [f".reg .u32 %{args[1]};"] + kernel
     elif uop == UOps.CONST: r[u] = const(args, dtype, mov=True)
@@ -157,7 +152,10 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
     elif uop == UOps.CAST:
       assert vin[0].dtype is not None
      cast(r[vin[0]], dtype, vin[0].dtype, bitcast=isinstance(args, tuple) and args[1], u=u)
-    elif uop == UOps.DEFINE_LOCAL: kk(*lang.render_local(ssa(u, 'local', lang.types[dtypes.ulong]), args[0], args[1], dtype))
+    elif uop == UOps.DEFINE_LOCAL:
+      # TODO: we should sum these, and fetch 0xC000 from somewhere
+      assert args[1]*dtype.itemsize <= 0xC000, "too large local"
+      kk(*lang.render_local(ssa(u, 'local', lang.types[dtypes.ulong]), args[0], args[1], dtype))
     elif uop is UOps.DEFINE_VAR:
       bufs.append((args.expr, dtype))
       r[u] = f"%{args.expr}"
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index e0164e4d7c..e966d54551 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -35,7 +35,7 @@ def _get_bytes(arg, get_str, get_sz, check) -> bytes:
   return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)
 
 class PTXCompiler(Compiler):
-  linearizer_opts = LinearizerOptions("CUDA", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], supports_float4=False)
+  linearizer_opts = LinearizerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], supports_float4=False)
   def __init__(self, arch:str):
     self.arch = arch
     PTXCompiler.linearizer_opts = PTXCompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80)
@@ -58,21 +58,27 @@ class CUDACompiler(Compiler):
     if status != 0: raise RuntimeError(f"compile failed: {_get_bytes(prog, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check).decode()}")
     return _get_bytes(prog, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, check)
 
+def cuda_disassemble(lib, arch):
+  try:
+    fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
+    with open(fn + ".ptx", "wb") as f: f.write(lib)
+    subprocess.run(["ptxas", f"-arch={arch}", "-o", fn, fn+".ptx"], check=True)
+    print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
+  except Exception as e: print("failed to generate SASS", str(e))
+
 class CUDAProgram:
   def __init__(self, device:CUDADevice, name:str, lib:bytes):
     self.device, self.name, self.lib = device, name, lib
     if DEBUG >= 5: print("\n".join([f"{i+1:>3} {line}" for i, line in enumerate(pretty_ptx(lib.decode('utf-8')).split("\n"))]))
-    if DEBUG >= 6:
-      try:
-        fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
-        with open(fn + ".ptx", "wb") as f: f.write(lib)
-        subprocess.run(["ptxas", f"-arch={device.arch}", "-o", fn, fn+".ptx"], check=True)
-        print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
-      except Exception as e: print("failed to generate SASS", str(e))
+    if DEBUG >= 6: cuda_disassemble(lib, device.arch)
 
     if not CUDACPU:
       check(cuda.cuCtxSetCurrent(self.device.context))
-      self.module = init_c_var(cuda.CUmodule(), lambda x: check(cuda.cuModuleLoadData(ctypes.byref(x), lib)))
+      self.module = cuda.CUmodule()
+      status = cuda.cuModuleLoadData(ctypes.byref(self.module), lib)
+      if status != 0:
+        cuda_disassemble(lib, device.arch)
+        raise RuntimeError("module load failed")
       check(cuda.cuModuleGetFunction(ctypes.byref(prg := cuda.CUfunction()), self.module, name.encode("utf-8")))
     self.prg = prg if not CUDACPU else lib
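A minimal sketch (not part of the diff) of why `suffix` is threaded into the disk-cache keys: with `suffix="PTX"` on `PTXCompiler.linearizer_opts`, `beam_search` and `time_linearizer` entries for the PTX renderer no longer collide with those of the default NVRTC CUDA renderer on the same device string, which is presumably also why `VERSION` is bumped in helpers.py so stale rows without the new field are not reused. `FakeOpts` and `beam_cache_key` below are illustrative stand-ins, not tinygrad APIs.

```python
# Illustrative only: FakeOpts and beam_cache_key are stand-ins, not tinygrad code.
from typing import NamedTuple

class FakeOpts(NamedTuple):
  device: str = "CUDA"
  suffix: str = ""  # "" for the NVRTC/CUDA renderer, "PTX" for the PTX renderer

def beam_cache_key(ast_key: str, amt: int, allow_test_size: bool, opts: FakeOpts) -> dict:
  # mirrors the shape of the key dict built in tinygrad/features/search.py after this diff
  return {"ast": ast_key, "amt": amt, "allow_test_size": allow_test_size,
          "device": opts.device, "suffix": opts.suffix}

cuda_key = beam_cache_key("ast_abc123", 8, True, FakeOpts())
ptx_key = beam_cache_key("ast_abc123", 8, True, FakeOpts(suffix="PTX"))
assert cuda_key != ptx_key  # cached BEAM results are kept separate per renderer
```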