From 9f1a54aceee8aa7639f16c77834c449bb2c7db5b Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 3 Sep 2023 10:21:02 -0700 Subject: [PATCH] pretty kernel in cstyle (#1746) * pretty kernel in cstyle * fix mem estimate * that made it slower * Revert "that made it slower" This reverts commit faa4cd0187b1d17ddbb6ce3ce0e842904a9001b4. --- tinygrad/codegen/kernel.py | 4 ++-- tinygrad/renderer/cstyle.py | 12 ++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index b1c96c8beb..194c9a37ba 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -2,7 +2,7 @@ from typing import NamedTuple, Optional, List, Tuple, cast, Dict import itertools from tinygrad.ops import LazyOp, MovementOps, FlopCounter, get_lazyop_info, ReduceOps from tinygrad.lazy import LazyBuffer -from tinygrad.helpers import dedup, dtypes, colored, prod, ImageDType, DType +from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType from tinygrad.runtime.lib import buf_is_kernel_arg from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape @@ -42,7 +42,7 @@ class Kernel: # fetch lazyop info self.info: FlopCounter = get_lazyop_info(cast(LazyOp, self.ast)) - self.mem_estimate: int = sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None) + self.mem_estimate: int = sum(x.dtype.itemsize*x.size for x in self.arg_bufs.keys()) # there's only allowed to be one reduceop reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps] diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 051158877f..40fe67ab25 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -117,6 +117,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T return f"{prefix}{c[prefix]-1}" r: Dict[UOp, str] = {} + child_count: DefaultDict[UOp, int] = defaultdict(int) + for ru in uops: + for v in ru.vin: + child_count[v] += 1 + for u in uops: uop,dtype,vin,args,_ = u if uop == UOps.LOOP: @@ -166,8 +171,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T raise NotImplementedError(f"WMMA not implemented for {args}") elif uop == UOps.ALU: assert dtype is not None - r[u] = ssa('alu') - kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {lang.code_for_op[args](*[r[x] for x in vin])};") + val = lang.code_for_op[args](*[r[x] for x in vin]) + if child_count[u] == 1: r[u] = val + else: + r[u] = ssa('alu') + kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};") elif uop == UOps.DEFINE_ACC: assert dtype is not None r[u] = ssa('acc')