fix ops print bug

This commit is contained in:
George Hotz
2023-03-02 10:33:03 -08:00
parent 0335cb86b9
commit dc88ad3342
2 changed files with 6 additions and 6 deletions

View File

@@ -319,7 +319,7 @@ class GPUCodegen(ASTKernel):
# concat kernel into prg
buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if hasattr(x._buf, "IMAGE") else self.lang.buffer_prefix+self.buftokens[i].decltype()+self.lang.buffer_suffix for i,x in enumerate(self.bufs)]
-self.prg = ' '.join(list(self.prekernel) + [f"{self.lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
+prg = ' '.join(list(self.prekernel) + [f"{self.lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
[', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + self.lang.extra_args)] +
[") {\n"] + self.kernel)
@@ -327,16 +327,16 @@ class GPUCodegen(ASTKernel):
function_name = ("re_S" if self.reduceop else "ew_S") + '_'.join([str(x) for x in self.bufs[0].shape if x != 1])
# painfully name the function
-if self.prg in GPUCodegen.kernel_name_cache:
-function_name = GPUCodegen.kernel_name_cache[self.prg]
+if prg in GPUCodegen.kernel_name_cache:
+function_name = GPUCodegen.kernel_name_cache[prg]
else:
GPUCodegen.kernel_cnt[function_name] += 1
if GPUCodegen.kernel_cnt[function_name]:
function_name = f"{function_name}{'_N'+str(GPUCodegen.kernel_cnt[function_name])}"
-GPUCodegen.kernel_name_cache[self.prg] = function_name
+GPUCodegen.kernel_name_cache[prg] = function_name
if DEBUG >= 3 and len(self.bufs_to_delete): print(f"deleting buffers {self.bufs_to_delete}")
-return ASTRunner(function_name, self.prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete,
+return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete,
self.output_shape[::-1] if len(self.output_shape) > 0 else [1],
(self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None,
op_estimate=self.info.flops, mem_estimate=sum(prod(x._base_shape) for x in self.bufs))

View File

@@ -112,7 +112,7 @@ class ASTRunner:
if et is not None: GlobalCounters.time_sum_s += et
if DEBUG >= 1:
print(f"**** {GlobalCounters.kernel_count:4d} {self.name:20s} args {len(bufs)-len(self.bufs_to_delete):5d} kernels {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
-(str() if DEBUG <= 1 else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS)"))
+(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS)"))
GlobalCounters.log_kernel(self.op_estimate, self.mem_estimate)
return et