From d3f370d69d8d0a519ccc08282bd48fbcf91338e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20O=C5=BC=C3=B3g?= Date: Tue, 22 Aug 2023 08:32:36 +0200 Subject: [PATCH] pretty ptx print on debug 5 --- tinygrad/runtime/ops_triton.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tinygrad/runtime/ops_triton.py b/tinygrad/runtime/ops_triton.py index 435e83e400..01ad141ca9 100644 --- a/tinygrad/runtime/ops_triton.py +++ b/tinygrad/runtime/ops_triton.py @@ -7,6 +7,8 @@ from tinygrad.ops import Compiled from tinygrad.runtime.ops_cuda import RawCUDABuffer from tinygrad.codegen.linearizer import LinearizerOptions from tinygrad.renderer.triton import uops_to_triton +from tinygrad.helpers import DEBUG +from tinygrad.runtime.ops_cuda import pretty_ptx class TritonProgram: @@ -22,6 +24,7 @@ class TritonProgram: codeObject = compile(prg, fn, "exec") exec(codeObject, globals()) self.program = triton_compile(globals()[name], signature=signature, device_type="cuda", debug=True).asm["ptx"] + if DEBUG >= 5: print(pretty_ptx(self.program)) self.program = cuda.module_from_buffer(self.program.encode('utf-8')).get_function(self.program.split(".visible .entry ")[1].split("(")[0]) def __call__(self, global_size, local_size, *args, wait=False) -> Any: