From 1c8817bea2154cf70a239a938974c75e77f730de Mon Sep 17 00:00:00 2001 From: Mesozoic Egg Date: Sun, 24 Nov 2024 15:21:14 +0800 Subject: [PATCH] simplify render_kernel --- tinygrad/renderer/ptx.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index cd50fe3300..f4a332e65c 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -151,11 +151,9 @@ class PTXRenderer(Renderer): def render_kernel(self, kernel, function_name, bufs, regs) -> str: kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"] - def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1) return (f"{self.kernel_prefix} {function_name}(\n\t" + - ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" + - '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) + - "\n}") + ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + + "\n)\n{\n" + '\n'.join([f"\t{k}" for k in kernel]) + "\n}") def render(self, name:str, uops:List[UOp]) -> str: kernel:List[str] = []