mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 06:48:22 -05:00
display sass for both cuda code and ptx (#1240)
* skip nvcc compile target cubin when using PTX
* actually we should generate sass for both ptx and cuda code
* Fixed formatting, should print the error anyway
* ensure subprocess.run throws exception
* fixed linting errors and checked before commit this time
This commit is contained in:
@@ -16,6 +16,7 @@ def pretty_ptx(s):
|
||||
s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
|
||||
s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
|
||||
return s
|
||||
def arch():
  # Build the "sm_<digits>" architecture string by concatenating the components of
  # compute_capability() for the device returned by pycuda.driver.Context.get_device();
  # e.g. a capability of (8, 6) yields "sm_86".
  capability = pycuda.driver.Context.get_device().compute_capability()
  return "sm_" + "".join(map(str, capability))
|
||||
|
||||
if getenv("CUDACPU", 0) == 1:
|
||||
import ctypes, ctypes.util
|
||||
@@ -52,18 +53,19 @@ else:
|
||||
|
||||
class CUDAProgram:
|
||||
def __init__(self, name:str, prg:str, binary=False):
|
||||
try:
|
||||
if DEBUG >= 6:
|
||||
fn = f"{tempfile.gettempdir()}/tinycuda_{hashlib.md5(prg.encode('utf-8')).hexdigest()}"
|
||||
with open(fn, "wb") as f:
|
||||
f.write(cuda_compile(prg, target="cubin", no_extern_c=True))
|
||||
sass = subprocess.check_output(['nvdisasm', fn]).decode('utf-8')
|
||||
print(sass)
|
||||
if not binary: prg = cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']).decode('utf-8')
|
||||
except cuda.CompileError as e:
|
||||
if DEBUG >= 3: print("FAILED TO BUILD", prg)
|
||||
raise e
|
||||
if not binary:
|
||||
try: prg = cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']).decode('utf-8')
|
||||
except cuda.CompileError as e:
|
||||
if DEBUG >= 3: print("FAILED TO BUILD", prg)
|
||||
raise e
|
||||
if DEBUG >= 5: print(pretty_ptx(prg))
|
||||
if DEBUG >= 6:
  # Debug aid: write the PTX to a temp file, assemble it with ptxas and print the nvdisasm output.
  # Best-effort only: any failure is reported and swallowed so it never prevents loading the program.
  try:
    import os  # imported locally so this section does not rely on the module-level import list
    # Join with the platform's separator. A hard-coded "\\" is only a directory separator on
    # Windows; elsewhere it becomes part of the file name and the path leaves the temp directory.
    fn = os.path.join(tempfile.gettempdir(), f"tinycuda_{hashlib.md5(prg.encode('utf-8')).hexdigest()}")
    with open(fn + ".ptx", "wb") as f: f.write(prg.encode('utf-8'))
    # check=True makes a non-zero ptxas exit raise, so a stale or missing output file is never disassembled.
    subprocess.run(["ptxas", f"-arch={arch()}", "-o", fn, fn+".ptx"], check=True)
    print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
  except Exception as e: print("failed to generate SASS", str(e))
|
||||
# TODO: name is wrong, so we get it from the ptx using hacks
|
||||
self.prg = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user