clean up display name

This commit is contained in:
George Hotz
2023-03-22 18:32:05 -07:00
parent b12b60af20
commit 2e18469fd4
3 changed files with 9 additions and 10 deletions

View File

@@ -2,7 +2,7 @@ from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple,
import math, collections
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
from tinygrad.helpers import getenv, all_same, partition, ImageDType, DEBUG, dtypes
from tinygrad.helpers import getenv, all_same, partition, ImageDType, DEBUG, dtypes, colored
from tinygrad.runtime.lib import RawConst
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode
from tinygrad.lazy import LazyBuffer
@@ -182,7 +182,7 @@ class CStyleCodegen(Linearizer):
# for renaming
kernel_cnt: Final[DefaultDict[str, int]] = collections.defaultdict(int)
kernel_name_cache: Final[Dict[str, str]] = {}
kernel_name_cache: Final[Dict[str, Tuple[str, str]]] = {}
def codegen(self):
self.process()
@@ -207,13 +207,12 @@ class CStyleCodegen(Linearizer):
for i,s in enumerate(local_size): global_size[i] *= s
# painfully name the function something unique
function_name = self.function_name
if prg in CStyleCodegen.kernel_name_cache: function_name = CStyleCodegen.kernel_name_cache[prg]
if prg in CStyleCodegen.kernel_name_cache: function_name, display_name = CStyleCodegen.kernel_name_cache[prg]
else:
CStyleCodegen.kernel_cnt[function_name] += 1
if CStyleCodegen.kernel_cnt[function_name] > 1: function_name = f"{function_name}{'n'+str(CStyleCodegen.kernel_cnt[function_name]-1)}"
CStyleCodegen.kernel_name_cache[prg] = function_name
CStyleCodegen.kernel_cnt[self.function_name] += 1
suffix = f"{'n'+str(CStyleCodegen.kernel_cnt[self.function_name]-1)}" if CStyleCodegen.kernel_cnt[self.function_name] > 1 else ""
CStyleCodegen.kernel_name_cache[prg] = function_name, display_name = self.function_name+suffix, self.display_name+colored(suffix, 'black', bright=True)
return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name),
global_size[::-1] if len(global_size) else [1], local_size[::-1] if len(local_size) else None,
op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name)
op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=display_name)