simplify render_kernel

This commit is contained in:
Mesozoic Egg
2024-11-24 15:21:14 +08:00
parent 5d28a202b5
commit 1c8817bea2

View File

@@ -151,11 +151,9 @@ class PTXRenderer(Renderer):
def render_kernel(self, kernel, function_name, bufs, regs) -> str:
kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
return (f"{self.kernel_prefix} {function_name}(\n\t" +
',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" +
'\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
"\n}")
',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) +
"\n)\n{\n" + '\n'.join([f"\t{k}" for k in kernel]) + "\n}")
def render(self, name:str, uops:List[UOp]) -> str:
kernel:List[str] = []