diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 549aca03e0..1508ccf527 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -99,18 +99,18 @@ class CStyleLanguage(NamedTuple): def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str: local_size: List[int] = [] - kernel,prekernel = [],[] + kernel,prekernel,bufs = [],[],[] #pend_close = None - bufs = [] depth = 1 def kk(s): kernel.append(" "*depth+s) c: DefaultDict[str, int] = defaultdict(int) - def ssa(prefix="t"): - nonlocal c - c[prefix] += 1 - return f"{prefix}{c[prefix]-1}" r: Dict[UOp, str] = {} + def ssa(u, prefix="t"): + nonlocal c, r + c[prefix] += 1 + r[u]=f"{prefix}{c[prefix]-1}" + return r[u] child_count: DefaultDict[UOp, int] = defaultdict(int) for ru in uops: @@ -120,8 +120,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st for u in uops: uop,dtype,vin,args,_ = u if uop == UOps.LOOP: - r[u] = ssa('ridx') - kk(lang.render_for(r[u], r[vin[0]], r[vin[1]])) + kk(lang.render_for(ssa(u,'ridx'), r[vin[0]], r[vin[1]])) depth += 1 elif uop == UOps.BARRIER: kk(lang.barrier) @@ -158,12 +157,10 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st if child_count[u] <= 1 or dtypes.is_int(dtype): # fix index rendering issue r[u] = val else: - r[u] = ssa('alu') - kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};") + kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'alu')} = {val};") elif uop == UOps.DEFINE_ACC: assert dtype is not None - r[u] = ssa('acc') - kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {lang.render_const(args, dtype)};") + kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};") elif uop == UOps.SPECIAL: xid = lang.gid if args[1].startswith("g") else lang.lid kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */") @@ -175,8 +172,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st assert dtype is not None val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL) if len(vin) > 2: val = lang.render_conditional(r[vin[2]], val, r[vin[3]]) - r[u] = ssa('val') - kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};") + kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'val')} = {val};") elif uop == UOps.STORE: if len(vin) == 2: kk(f"{r[vin[0]]} = {r[vin[1]]};") @@ -185,11 +181,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)) elif uop == UOps.CAST and dtype is not None and dtype.sz > 1: val = lang.render_cast([r[x] for x in vin], dtype) - if child_count[u] <= 1: - r[u] = val - else: - r[u] = ssa('cast') - kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};") + if child_count[u] <= 1: r[u] = val + else: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'cast')} = {val};") elif uop == UOps.DEFINE_LOCAL: if lang.external_local_bufs: prekernel.append(lang.render_local(args[0], args[1]))