mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
remove the stupid register class (#721)
* remove the stupid register class
* touchups
* colorful display name
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set
|
||||
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union
|
||||
import math, collections
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
|
||||
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
|
||||
@@ -56,7 +56,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
|
||||
BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({b}*{c})+{a})"
|
||||
}
|
||||
|
||||
def uops_to_cstyle(uops:List[UOp], bufs:List[LazyBuffer], bufnames:List[str], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
|
||||
def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
|
||||
def group_float4(grp:List[str]) -> str:
|
||||
if all(g.endswith(e) for g,e in zip(grp, [".x", ".y", ".z", ".w"])) and all_same([g.split(".")[0] for g in grp]): return grp[0].split(".")[0]
|
||||
else: return f"{lang.float4}({','.join(g for g in grp)})"
|
||||
@@ -67,6 +67,8 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[LazyBuffer], bufnames:List[str], la
|
||||
local_size = []
|
||||
pend_close = None
|
||||
|
||||
bufnames = ["temp" if isinstance(b, LocalBuffer) else f"data{i}" for i,b in enumerate(bufs)]
|
||||
|
||||
depth = 0
|
||||
def kk(s): kernel.append(" "*depth+s)
|
||||
|
||||
@@ -200,7 +202,7 @@ class CStyleCodegen(Linearizer):
|
||||
self.hand_coded_optimizations()
|
||||
self.linearize()
|
||||
|
||||
prg, global_size, local_size = uops_to_cstyle(self.uops, self.bufs, [x.name for x in self.registers], self.lang)
|
||||
prg, global_size, local_size = uops_to_cstyle(self.uops, self.bufs, self.lang)
|
||||
|
||||
# if we have local_sizes, we have to correct the global_size
|
||||
for i,s in enumerate(local_size): global_size[i] *= s
|
||||
@@ -215,4 +217,4 @@ class CStyleCodegen(Linearizer):
|
||||
|
||||
return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name),
|
||||
global_size[::-1] if len(global_size) else [1], local_size[::-1] if len(local_size) else None,
|
||||
op_estimate=self.info.flops, mem_estimate=self.mem_estimate)
|
||||
op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name)
|
||||
|
||||
Reference in New Issue
Block a user