mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
pretty ptx print on debug 5
This commit is contained in:
@@ -7,6 +7,8 @@ from tinygrad.ops import Compiled
|
||||
from tinygrad.runtime.ops_cuda import RawCUDABuffer
|
||||
from tinygrad.codegen.linearizer import LinearizerOptions
|
||||
from tinygrad.renderer.triton import uops_to_triton
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.runtime.ops_cuda import pretty_ptx
|
||||
|
||||
|
||||
class TritonProgram:
|
||||
@@ -22,6 +24,7 @@ class TritonProgram:
|
||||
codeObject = compile(prg, fn, "exec")
|
||||
exec(codeObject, globals())
|
||||
self.program = triton_compile(globals()[name], signature=signature, device_type="cuda", debug=True).asm["ptx"]
|
||||
if DEBUG >= 5: print(pretty_ptx(self.program))
|
||||
self.program = cuda.module_from_buffer(self.program.encode('utf-8')).get_function(self.program.split(".visible .entry ")[1].split("(")[0])
|
||||
|
||||
def __call__(self, global_size, local_size, *args, wait=False) -> Any:
|
||||
|
||||
Reference in New Issue
Block a user