clean up display name: pad it at print time and cache it alongside the function name

This commit is contained in:
George Hotz
2023-03-22 18:32:05 -07:00
parent b12b60af20
commit 2e18469fd4
3 changed files with 9 additions and 10 deletions

View File

@@ -2,7 +2,7 @@ from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple,
import math, collections
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
from tinygrad.helpers import getenv, all_same, partition, ImageDType, DEBUG, dtypes
from tinygrad.helpers import getenv, all_same, partition, ImageDType, DEBUG, dtypes, colored
from tinygrad.runtime.lib import RawConst
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode
from tinygrad.lazy import LazyBuffer
@@ -182,7 +182,7 @@ class CStyleCodegen(Linearizer):
# for renaming
kernel_cnt: Final[DefaultDict[str, int]] = collections.defaultdict(int)
kernel_name_cache: Final[Dict[str, str]] = {}
kernel_name_cache: Final[Dict[str, Tuple[str, str]]] = {}
def codegen(self):
self.process()
@@ -207,13 +207,12 @@ class CStyleCodegen(Linearizer):
for i,s in enumerate(local_size): global_size[i] *= s
# painfully name the function something unique
function_name = self.function_name
if prg in CStyleCodegen.kernel_name_cache: function_name = CStyleCodegen.kernel_name_cache[prg]
if prg in CStyleCodegen.kernel_name_cache: function_name, display_name = CStyleCodegen.kernel_name_cache[prg]
else:
CStyleCodegen.kernel_cnt[function_name] += 1
if CStyleCodegen.kernel_cnt[function_name] > 1: function_name = f"{function_name}{'n'+str(CStyleCodegen.kernel_cnt[function_name]-1)}"
CStyleCodegen.kernel_name_cache[prg] = function_name
CStyleCodegen.kernel_cnt[self.function_name] += 1
suffix = f"{'n'+str(CStyleCodegen.kernel_cnt[self.function_name]-1)}" if CStyleCodegen.kernel_cnt[self.function_name] > 1 else ""
CStyleCodegen.kernel_name_cache[prg] = function_name, display_name = self.function_name+suffix, self.display_name+colored(suffix, 'black', bright=True)
return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name),
global_size[::-1] if len(global_size) else [1], local_size[::-1] if len(local_size) else None,
op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name)
op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=display_name)

View File

@@ -149,7 +149,7 @@ class Linearizer:
# kernel name (before late upcast)
self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) for x in self.full_shape])
self.display_name = ("r_" if self.reduceop else "E_") + colored('_', 'black', bright=True).join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())]) + " " * (23-len(self.function_name))
self.display_name = ("r_" if self.reduceop else "E_") + colored('_', 'black', bright=True).join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
# parse AST
loaded_buffers = {}

View File

@@ -93,7 +93,7 @@ class ASTRunner:
if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs)
if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=DEBUG>=2): GlobalCounters.time_sum_s += et
if DEBUG >= 2:
print(f"*** {GlobalCounters.kernel_count:4d} {self.display_name if self.display_name is not None else self.name:20s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
print(f"*** {GlobalCounters.kernel_count:4d} {(self.display_name+' '*(23-len(self.name))) if self.display_name is not None else self.name:23s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS, {self.mem_estimate/(et*1e9):6.2f} GB/s)"))
GlobalCounters.kernel_count += 1
GlobalCounters.global_ops += self.op_estimate