mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 23:08:06 -05:00
pretty kernel in cstyle (#1746)
* pretty kernel in cstyle
* fix mem estimate
* that made it slower
* Revert "that made it slower"
This reverts commit faa4cd0187.
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import NamedTuple, Optional, List, Tuple, cast, Dict
|
||||
import itertools
|
||||
from tinygrad.ops import LazyOp, MovementOps, FlopCounter, get_lazyop_info, ReduceOps
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
-from tinygrad.helpers import dedup, dtypes, colored, prod, ImageDType, DType
|
||||
+from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType
|
||||
from tinygrad.runtime.lib import buf_is_kernel_arg
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
|
||||
|
||||
@@ -42,7 +42,7 @@ class Kernel:
|
||||
|
||||
# fetch lazyop info
|
||||
self.info: FlopCounter = get_lazyop_info(cast(LazyOp, self.ast))
|
||||
-self.mem_estimate: int = sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None)
|
||||
+self.mem_estimate: int = sum(x.dtype.itemsize*x.size for x in self.arg_bufs.keys())
|
||||
|
||||
# there's only allowed to be one reduceop
|
||||
reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps]
|
||||
|
||||
@@ -117,6 +117,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
|
||||
return f"{prefix}{c[prefix]-1}"
|
||||
r: Dict[UOp, str] = {}
|
||||
|
||||
+child_count: DefaultDict[UOp, int] = defaultdict(int)
|
||||
+for ru in uops:
|
||||
+for v in ru.vin:
|
||||
+child_count[v] += 1
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,vin,args,_ = u
|
||||
if uop == UOps.LOOP:
|
||||
@@ -166,8 +171,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
|
||||
raise NotImplementedError(f"WMMA not implemented for {args}")
|
||||
elif uop == UOps.ALU:
|
||||
assert dtype is not None
|
||||
-r[u] = ssa('alu')
|
||||
-kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {lang.code_for_op[args](*[r[x] for x in vin])};")
|
||||
+val = lang.code_for_op[args](*[r[x] for x in vin])
|
||||
+if child_count[u] == 1: r[u] = val
|
||||
+else:
|
||||
+r[u] = ssa('alu')
|
||||
+kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};")
|
||||
elif uop == UOps.DEFINE_ACC:
|
||||
assert dtype is not None
|
||||
r[u] = ssa('acc')
|
||||
|
||||
Reference in New Issue
Block a user