Mirror of https://github.com/tinygrad/tinygrad.git (last synced 2026-04-07 03:00:26 -04:00).
fix ops print bug
This commit is contained in:
@@ -319,7 +319,7 @@ class GPUCodegen(ASTKernel):
|
||||
|
||||
# concat kernel into prg
|
||||
buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if hasattr(x._buf, "IMAGE") else self.lang.buffer_prefix+self.buftokens[i].decltype()+self.lang.buffer_suffix for i,x in enumerate(self.bufs)]
|
||||
self.prg = ' '.join(list(self.prekernel) + [f"{self.lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
|
||||
prg = ' '.join(list(self.prekernel) + [f"{self.lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
|
||||
[', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + self.lang.extra_args)] +
|
||||
[") {\n"] + self.kernel)
|
||||
|
||||
@@ -327,16 +327,16 @@ class GPUCodegen(ASTKernel):
|
||||
function_name = ("re_S" if self.reduceop else "ew_S") + '_'.join([str(x) for x in self.bufs[0].shape if x != 1])
|
||||
|
||||
# painfully name the function
|
||||
if self.prg in GPUCodegen.kernel_name_cache:
|
||||
function_name = GPUCodegen.kernel_name_cache[self.prg]
|
||||
if prg in GPUCodegen.kernel_name_cache:
|
||||
function_name = GPUCodegen.kernel_name_cache[prg]
|
||||
else:
|
||||
GPUCodegen.kernel_cnt[function_name] += 1
|
||||
if GPUCodegen.kernel_cnt[function_name]:
|
||||
function_name = f"{function_name}{'_N'+str(GPUCodegen.kernel_cnt[function_name])}"
|
||||
GPUCodegen.kernel_name_cache[self.prg] = function_name
|
||||
GPUCodegen.kernel_name_cache[prg] = function_name
|
||||
|
||||
if DEBUG >= 3 and len(self.bufs_to_delete): print(f"deleting buffers {self.bufs_to_delete}")
|
||||
return ASTRunner(function_name, self.prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete,
|
||||
return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete,
|
||||
self.output_shape[::-1] if len(self.output_shape) > 0 else [1],
|
||||
(self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None,
|
||||
op_estimate=self.info.flops, mem_estimate=sum(prod(x._base_shape) for x in self.bufs))
|
||||
|
||||
@@ -112,7 +112,7 @@ class ASTRunner:
|
||||
if et is not None: GlobalCounters.time_sum_s += et
|
||||
if DEBUG >= 1:
|
||||
print(f"**** {GlobalCounters.kernel_count:4d} {self.name:20s} args {len(bufs)-len(self.bufs_to_delete):5d} kernels {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
||||
(str() if DEBUG <= 1 else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS)"))
|
||||
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS)"))
|
||||
GlobalCounters.log_kernel(self.op_estimate, self.mem_estimate)
|
||||
return et
|
||||
|
||||
|
||||
Reference in New Issue
Block a user