pretty kernel in cstyle

This commit is contained in:
George Hotz
2023-09-03 09:32:38 -07:00
parent e910e0e62c
commit 0458120cf2

View File

@@ -117,6 +117,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
return f"{prefix}{c[prefix]-1}"
r: Dict[UOp, str] = {}
child_count: DefaultDict[UOp, int] = defaultdict(int)
for ru in uops:
for v in ru.vin:
child_count[v] += 1
for u in uops:
uop,dtype,vin,args,_ = u
if uop == UOps.LOOP:
@@ -166,8 +171,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
raise NotImplementedError(f"WMMA not implemented for {args}")
elif uop == UOps.ALU:
assert dtype is not None
r[u] = ssa('alu')
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {lang.code_for_op[args](*[r[x] for x in vin])};")
val = lang.code_for_op[args](*[r[x] for x in vin])
if child_count[u] == 1: r[u] = val
else:
r[u] = ssa('alu')
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};")
elif uop == UOps.DEFINE_ACC:
assert dtype is not None
r[u] = ssa('acc')