Display SASS for both CUDA code and PTX (#1240)

* Skip the nvcc compile to the cubin target when using PTX

* Actually, we should generate SASS for both PTX and CUDA code

* Fixed formatting; the error should be printed anyway

* Ensure subprocess.run raises an exception on failure

* Fixed linting errors and checked before committing this time
This commit is contained in:
Oddity
2023-07-15 02:36:04 -05:00
committed by GitHub
parent 264d467f2b
commit 7399f6dad7

View File

@@ -16,6 +16,7 @@ def pretty_ptx(s):
s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
return s
def arch():
  """Return "sm_" followed by the compute capability digits of the current pycuda context's device."""
  # compute_capability() yields the version components; they are concatenated without a separator
  capability = pycuda.driver.Context.get_device().compute_capability()
  return "sm_" + "".join(map(str, capability))
if getenv("CUDACPU", 0) == 1:
import ctypes, ctypes.util
@@ -52,18 +53,19 @@ else:
class CUDAProgram:
def __init__(self, name:str, prg:str, binary=False):
try:
if DEBUG >= 6:
fn = f"{tempfile.gettempdir()}/tinycuda_{hashlib.md5(prg.encode('utf-8')).hexdigest()}"
with open(fn, "wb") as f:
f.write(cuda_compile(prg, target="cubin", no_extern_c=True))
sass = subprocess.check_output(['nvdisasm', fn]).decode('utf-8')
print(sass)
if not binary: prg = cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']).decode('utf-8')
except cuda.CompileError as e:
if DEBUG >= 3: print("FAILED TO BUILD", prg)
raise e
if not binary:
try: prg = cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']).decode('utf-8')
except cuda.CompileError as e:
if DEBUG >= 3: print("FAILED TO BUILD", prg)
raise e
if DEBUG >= 5: print(pretty_ptx(prg))
if DEBUG >= 6:
try:
fn = f"{tempfile.gettempdir()}\\tinycuda_{hashlib.md5(prg.encode('utf-8')).hexdigest()}"
with open(fn + ".ptx", "wb") as f: f.write(prg.encode('utf-8'))
subprocess.run(["ptxas", f"-arch={arch()}", "-o", fn, fn+".ptx"], check=True)
print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
except Exception as e: print("failed to generate SASS", str(e))
# TODO: name is wrong, so we get it from the ptx using hacks
self.prg = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0])