add ptx formatter + syntax highlighter (#1128)

This commit is contained in:
cloud11665
2023-07-06 02:56:09 +02:00
committed by GitHub
parent d1356cac27
commit b7369ffcff

View File

@@ -1,13 +1,24 @@
import subprocess
from typing import Optional
import time
import re
import numpy as np
from pycuda.compiler import compile as cuda_compile # type: ignore
from tinygrad.helpers import DEBUG, getenv, fromimport
from tinygrad.helpers import DEBUG, getenv, fromimport, colored
from tinygrad.ops import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer
from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
def pretty_ptx(s):
# all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers
s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types
s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers
s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
return s
if getenv("CUDACPU", 0) == 1:
import ctypes, ctypes.util
lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
@@ -53,7 +64,7 @@ class CUDAProgram:
except cuda.CompileError as e:
if DEBUG >= 3: print("FAILED TO BUILD", prg)
raise e
if DEBUG >= 5: print(prg)
if DEBUG >= 5: print(pretty_ptx(prg))
# TODO: name is wrong, so we get it from the ptx using hacks
self.prg = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0])