From 7399f6dad7099c540bfa84aec137d704fe2231d0 Mon Sep 17 00:00:00 2001
From: Oddity <32976445+NotNotOddity@users.noreply.github.com>
Date: Sat, 15 Jul 2023 02:36:04 -0500
Subject: [PATCH] display sass for both cuda code and ptx (#1240)

* skip nvcc compile target cubin when using PTX

* actually we should generate sass for both ptx and cuda code

* Fixed formatting, should print the error anyway

* ensure subprocess.run throws exception

* fixed linting errors and checked before commit this time
---
 tinygrad/runtime/ops_cuda.py | 24 +++++++++++++-----------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index f37c2299e5..746f02e0cb 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -16,6 +16,7 @@ def pretty_ptx(s):
   s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
   s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
   return s
+def arch(): return "sm_" + "".join([str(x) for x in pycuda.driver.Context.get_device().compute_capability()])
 
 if getenv("CUDACPU", 0) == 1:
   import ctypes, ctypes.util
@@ -52,18 +53,19 @@ else:
 
 class CUDAProgram:
   def __init__(self, name:str, prg:str, binary=False):
-    try:
-      if DEBUG >= 6:
-        fn = f"{tempfile.gettempdir()}/tinycuda_{hashlib.md5(prg.encode('utf-8')).hexdigest()}"
-        with open(fn, "wb") as f:
-          f.write(cuda_compile(prg, target="cubin", no_extern_c=True))
-        sass = subprocess.check_output(['nvdisasm', fn]).decode('utf-8')
-        print(sass)
-      if not binary: prg = cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']).decode('utf-8')
-    except cuda.CompileError as e:
-      if DEBUG >= 3: print("FAILED TO BUILD", prg)
-      raise e
+    if not binary:
+      try: prg = cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']).decode('utf-8')
+      except cuda.CompileError as e:
+        if DEBUG >= 3: print("FAILED TO BUILD", prg)
+        raise e
     if DEBUG >= 5: print(pretty_ptx(prg))
+    if DEBUG >= 6:
+      try:
+        fn = f"{tempfile.gettempdir()}/tinycuda_{hashlib.md5(prg.encode('utf-8')).hexdigest()}"  # "/" is valid on POSIX and Windows
+        with open(fn + ".ptx", "wb") as f: f.write(prg.encode('utf-8'))
+        subprocess.run(["ptxas", f"-arch={arch()}", "-o", fn, fn+".ptx"], check=True)
+        print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
+      except Exception as e: print("failed to generate SASS", str(e))
     # TODO: name is wrong, so we get it from the ptx using hacks
     self.prg = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0])
 