mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 14:58:46 -05:00
simplify render_kernel
This commit is contained in:
@@ -151,11 +151,9 @@ class PTXRenderer(Renderer):
|
||||
|
||||
def render_kernel(self, kernel, function_name, bufs, regs) -> str:
|
||||
kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
|
||||
def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
|
||||
return (f"{self.kernel_prefix} {function_name}(\n\t" +
|
||||
',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" +
|
||||
'\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
|
||||
"\n}")
|
||||
',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) +
|
||||
"\n)\n{\n" + '\n'.join([f"\t{k}" for k in kernel]) + "\n}")
|
||||
|
||||
def render(self, name:str, uops:List[UOp]) -> str:
|
||||
kernel:List[str] = []
|
||||
|
||||
Reference in New Issue
Block a user