mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-27 07:48:07 -05:00
pretty kernel in cstyle
This commit is contained in:
@@ -117,6 +117,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
|
||||
return f"{prefix}{c[prefix]-1}"
|
||||
r: Dict[UOp, str] = {}
|
||||
|
||||
child_count: DefaultDict[UOp, int] = defaultdict(int)
|
||||
for ru in uops:
|
||||
for v in ru.vin:
|
||||
child_count[v] += 1
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,vin,args,_ = u
|
||||
if uop == UOps.LOOP:
|
||||
@@ -166,8 +171,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
|
||||
raise NotImplementedError(f"WMMA not implemented for {args}")
|
||||
elif uop == UOps.ALU:
|
||||
assert dtype is not None
|
||||
r[u] = ssa('alu')
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {lang.code_for_op[args](*[r[x] for x in vin])};")
|
||||
val = lang.code_for_op[args](*[r[x] for x in vin])
|
||||
if child_count[u] == 1: r[u] = val
|
||||
else:
|
||||
r[u] = ssa('alu')
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};")
|
||||
elif uop == UOps.DEFINE_ACC:
|
||||
assert dtype is not None
|
||||
r[u] = ssa('acc')
|
||||
|
||||
Reference in New Issue
Block a user