diff --git a/tinygrad/runtime/ops_triton.py b/tinygrad/runtime/ops_triton.py index 435e83e400..01ad141ca9 100644 --- a/tinygrad/runtime/ops_triton.py +++ b/tinygrad/runtime/ops_triton.py @@ -7,6 +7,8 @@ from tinygrad.ops import Compiled from tinygrad.runtime.ops_cuda import RawCUDABuffer from tinygrad.codegen.linearizer import LinearizerOptions from tinygrad.renderer.triton import uops_to_triton +from tinygrad.helpers import DEBUG +from tinygrad.runtime.ops_cuda import pretty_ptx class TritonProgram: @@ -22,6 +24,7 @@ class TritonProgram: codeObject = compile(prg, fn, "exec") exec(codeObject, globals()) self.program = triton_compile(globals()[name], signature=signature, device_type="cuda", debug=True).asm["ptx"] + if DEBUG >= 5: print(pretty_ptx(self.program)) self.program = cuda.module_from_buffer(self.program.encode('utf-8')).get_function(self.program.split(".visible .entry ")[1].split("(")[0]) def __call__(self, global_size, local_size, *args, wait=False) -> Any: