remove the stupid register class (#721)

* remove the stupid register class

* touchups

* colorful display name
This commit is contained in:
George Hotz
2023-03-20 15:45:12 -07:00
committed by GitHub
parent 30b795874a
commit 06abbbfe7c
4 changed files with 126 additions and 133 deletions

View File

@@ -1,4 +1,4 @@
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union
import math, collections
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
@@ -56,7 +56,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({b}*{c})+{a})"
}
def uops_to_cstyle(uops:List[UOp], bufs:List[LazyBuffer], bufnames:List[str], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
def group_float4(grp:List[str]) -> str:
if all(g.endswith(e) for g,e in zip(grp, [".x", ".y", ".z", ".w"])) and all_same([g.split(".")[0] for g in grp]): return grp[0].split(".")[0]
else: return f"{lang.float4}({','.join(g for g in grp)})"
@@ -67,6 +67,8 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[LazyBuffer], bufnames:List[str], la
local_size = []
pend_close = None
bufnames = ["temp" if isinstance(b, LocalBuffer) else f"data{i}" for i,b in enumerate(bufs)]
depth = 0
def kk(s): kernel.append(" "*depth+s)
@@ -200,7 +202,7 @@ class CStyleCodegen(Linearizer):
self.hand_coded_optimizations()
self.linearize()
prg, global_size, local_size = uops_to_cstyle(self.uops, self.bufs, [x.name for x in self.registers], self.lang)
prg, global_size, local_size = uops_to_cstyle(self.uops, self.bufs, self.lang)
# if we have local_sizes, we have to correct the global_size
for i,s in enumerate(local_size): global_size[i] *= s
@@ -215,4 +217,4 @@ class CStyleCodegen(Linearizer):
return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name),
global_size[::-1] if len(global_size) else [1], local_size[::-1] if len(local_size) else None,
op_estimate=self.info.flops, mem_estimate=self.mem_estimate)
op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name)