make ssa assign r[u] (#1887)

This commit is contained in:
Szymon Ożóg
2023-09-21 04:20:20 +02:00
committed by GitHub
parent 9450e41f70
commit bd3444797b

View File

@@ -99,18 +99,18 @@ class CStyleLanguage(NamedTuple):
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str:
local_size: List[int] = []
kernel,prekernel = [],[]
kernel,prekernel,bufs = [],[],[]
#pend_close = None
bufs = []
depth = 1
def kk(s): kernel.append(" "*depth+s)
c: DefaultDict[str, int] = defaultdict(int)
def ssa(prefix="t"):
nonlocal c
c[prefix] += 1
return f"{prefix}{c[prefix]-1}"
r: Dict[UOp, str] = {}
def ssa(u, prefix="t"):
nonlocal c, r
c[prefix] += 1
r[u]=f"{prefix}{c[prefix]-1}"
return r[u]
child_count: DefaultDict[UOp, int] = defaultdict(int)
for ru in uops:
@@ -120,8 +120,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
for u in uops:
uop,dtype,vin,args,_ = u
if uop == UOps.LOOP:
r[u] = ssa('ridx')
kk(lang.render_for(r[u], r[vin[0]], r[vin[1]]))
kk(lang.render_for(ssa(u,'ridx'), r[vin[0]], r[vin[1]]))
depth += 1
elif uop == UOps.BARRIER:
kk(lang.barrier)
@@ -158,12 +157,10 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
if child_count[u] <= 1 or dtypes.is_int(dtype): # fix index rendering issue
r[u] = val
else:
r[u] = ssa('alu')
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};")
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'alu')} = {val};")
elif uop == UOps.DEFINE_ACC:
assert dtype is not None
r[u] = ssa('acc')
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {lang.render_const(args, dtype)};")
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
elif uop == UOps.SPECIAL:
xid = lang.gid if args[1].startswith("g") else lang.lid
kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */")
@@ -175,8 +172,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
assert dtype is not None
val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)
if len(vin) > 2: val = lang.render_conditional(r[vin[2]], val, r[vin[3]])
r[u] = ssa('val')
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};")
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'val')} = {val};")
elif uop == UOps.STORE:
if len(vin) == 2:
kk(f"{r[vin[0]]} = {r[vin[1]]};")
@@ -185,11 +181,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
elif uop == UOps.CAST and dtype is not None and dtype.sz > 1:
val = lang.render_cast([r[x] for x in vin], dtype)
if child_count[u] <= 1:
r[u] = val
else:
r[u] = ssa('cast')
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};")
if child_count[u] <= 1: r[u] = val
else: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'cast')} = {val};")
elif uop == UOps.DEFINE_LOCAL:
if lang.external_local_bufs:
prekernel.append(lang.render_local(args[0], args[1]))