pretty ptx print on debug 5

This commit is contained in:
Szymon Ożóg
2023-08-22 08:32:36 +02:00
parent 4e18f4e7ae
commit d3f370d69d

View File

@@ -7,6 +7,8 @@ from tinygrad.ops import Compiled
from tinygrad.runtime.ops_cuda import RawCUDABuffer
from tinygrad.codegen.linearizer import LinearizerOptions
from tinygrad.renderer.triton import uops_to_triton
from tinygrad.helpers import DEBUG
from tinygrad.runtime.ops_cuda import pretty_ptx
class TritonProgram:
@@ -22,6 +24,7 @@ class TritonProgram:
codeObject = compile(prg, fn, "exec")
exec(codeObject, globals())
self.program = triton_compile(globals()[name], signature=signature, device_type="cuda", debug=True).asm["ptx"]
if DEBUG >= 5: print(pretty_ptx(self.program))
self.program = cuda.module_from_buffer(self.program.encode('utf-8')).get_function(self.program.split(".visible .entry ")[1].split("(")[0])
def __call__(self, global_size, local_size, *args, wait=False) -> Any: