From 0458120cf226c8a98d50c87a7b5d202185092d65 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 3 Sep 2023 09:32:38 -0700 Subject: [PATCH] pretty kernel in cstyle --- tinygrad/renderer/cstyle.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 051158877f..40fe67ab25 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -117,6 +117,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T return f"{prefix}{c[prefix]-1}" r: Dict[UOp, str] = {} + child_count: DefaultDict[UOp, int] = defaultdict(int) + for ru in uops: + for v in ru.vin: + child_count[v] += 1 + for u in uops: uop,dtype,vin,args,_ = u if uop == UOps.LOOP: @@ -166,8 +171,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T raise NotImplementedError(f"WMMA not implemented for {args}") elif uop == UOps.ALU: assert dtype is not None - r[u] = ssa('alu') - kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {lang.code_for_op[args](*[r[x] for x in vin])};") + val = lang.code_for_op[args](*[r[x] for x in vin]) + if child_count[u] == 1: r[u] = val + else: + r[u] = ssa('alu') + kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};") elif uop == UOps.DEFINE_ACC: assert dtype is not None r[u] = ssa('acc')