mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
add ptx formatter + syntax highlighter (#1128)
This commit is contained in:
@@ -1,13 +1,24 @@
|
||||
import subprocess
|
||||
from typing import Optional
|
||||
import time
|
||||
import re
|
||||
import numpy as np
|
||||
from pycuda.compiler import compile as cuda_compile # type: ignore
|
||||
from tinygrad.helpers import DEBUG, getenv, fromimport
|
||||
from tinygrad.helpers import DEBUG, getenv, fromimport, colored
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer
|
||||
from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
|
||||
|
||||
def pretty_ptx(s):
  """Return the PTX source text `s` with its tokens wrapped by `colored` for display.

  Six regex passes run in a fixed order, each working on the output of the previous one:
  identifiers (blue), types (green), the leading word of every `;`-terminated line (yellow),
  numeric literals (yellow), state spaces (magenta) and directives (magenta).

  Args:
    s: PTX source as a single string (may span many lines).

  Returns:
    A new string in which every recognised token has been replaced by `colored(token, <color>)`.
    Text that matches no pass is returned unchanged.
  """
  def paint(color):
    # replacement callback: recolor the whole match, which is exactly the token
    return lambda m: colored(m[0], color)

  # The delimiter that must surround a token is asserted with zero-width lookbehind/lookahead rather than captured.
  # re.sub only replaces non-overlapping matches, so a captured trailing delimiter could not also act as the leading
  # delimiter of the next token: the second type in `cvt.rn.f32.s32`, or `%r2` in `%r1,%r2`, used to be left uncolored.
  s = re.sub(r'(?<=[!@<\[\s,\+\-;\n])(?:(?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))(?=[<>\]\s,\+\-;\n\)])',
             paint("blue"), s, flags=re.M) # identifiers
  s = re.sub(r'(?<=.)(?:(?:b|s|u|f)(?:8|16|32|64)|pred)(?=[\.\s])', paint("green"), s, flags=re.M) # types
  # first word of each line that ends in `;`; the rest of the line is kept as-is
  s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m: m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
  s = re.sub(r'(?<=[<>\[\]\s,\+\-;])(?:(?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))(?=[<>\[\]\s,\+\-;])',
             paint("yellow"), s, flags=re.M) # numbers
  s = re.sub(r'(?<=\.)(?:param|reg|global)', paint("magenta"), s, flags=re.M) # space
  s = re.sub(r'(?<=\.)(?:version|target|address_size|visible|entry)', paint("magenta"), s, flags=re.M) # derivatives
  return s
|
||||
|
||||
if getenv("CUDACPU", 0) == 1:
|
||||
import ctypes, ctypes.util
|
||||
lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
|
||||
@@ -53,7 +64,7 @@ class CUDAProgram:
|
||||
except cuda.CompileError as e:
|
||||
if DEBUG >= 3: print("FAILED TO BUILD", prg)
|
||||
raise e
|
||||
if DEBUG >= 5: print(prg)
|
||||
if DEBUG >= 5: print(pretty_ptx(prg))
|
||||
# TODO: name is wrong, so we get it from the ptx using hacks
|
||||
self.prg = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user