pretty kernel in cstyle (#1746)

* pretty kernel in cstyle

* fix mem estimate

* that made it slower

* Revert "that made it slower"

This reverts commit faa4cd0187.
This commit is contained in:
George Hotz
2023-09-03 10:21:02 -07:00
committed by GitHub
parent e910e0e62c
commit 9f1a54acee
2 changed files with 12 additions and 4 deletions

View File

@@ -2,7 +2,7 @@ from typing import NamedTuple, Optional, List, Tuple, cast, Dict
import itertools
from tinygrad.ops import LazyOp, MovementOps, FlopCounter, get_lazyop_info, ReduceOps
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import dedup, dtypes, colored, prod, ImageDType, DType
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType
from tinygrad.runtime.lib import buf_is_kernel_arg
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
@@ -42,7 +42,7 @@ class Kernel:
# fetch lazyop info
self.info: FlopCounter = get_lazyop_info(cast(LazyOp, self.ast))
self.mem_estimate: int = sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None)
self.mem_estimate: int = sum(x.dtype.itemsize*x.size for x in self.arg_bufs.keys())
# there's only allowed to be one reduceop
reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps]

View File

@@ -117,6 +117,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
return f"{prefix}{c[prefix]-1}"
r: Dict[UOp, str] = {}
child_count: DefaultDict[UOp, int] = defaultdict(int)
for ru in uops:
for v in ru.vin:
child_count[v] += 1
for u in uops:
uop,dtype,vin,args,_ = u
if uop == UOps.LOOP:
@@ -166,8 +171,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
raise NotImplementedError(f"WMMA not implemented for {args}")
elif uop == UOps.ALU:
assert dtype is not None
r[u] = ssa('alu')
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {lang.code_for_op[args](*[r[x] for x in vin])};")
val = lang.code_for_op[args](*[r[x] for x in vin])
if child_count[u] == 1: r[u] = val
else:
r[u] = ssa('alu')
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};")
elif uop == UOps.DEFINE_ACC:
assert dtype is not None
r[u] = ssa('acc')