mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 23:08:06 -05:00
make ssa assign r[u] (#1887)
This commit is contained in:
@@ -99,18 +99,18 @@ class CStyleLanguage(NamedTuple):
|
||||
|
||||
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str:
|
||||
local_size: List[int] = []
|
||||
kernel,prekernel = [],[]
|
||||
kernel,prekernel,bufs = [],[],[]
|
||||
#pend_close = None
|
||||
bufs = []
|
||||
depth = 1
|
||||
def kk(s): kernel.append(" "*depth+s)
|
||||
|
||||
c: DefaultDict[str, int] = defaultdict(int)
|
||||
def ssa(prefix="t"):
|
||||
nonlocal c
|
||||
c[prefix] += 1
|
||||
return f"{prefix}{c[prefix]-1}"
|
||||
r: Dict[UOp, str] = {}
|
||||
def ssa(u, prefix="t"):
|
||||
nonlocal c, r
|
||||
c[prefix] += 1
|
||||
r[u]=f"{prefix}{c[prefix]-1}"
|
||||
return r[u]
|
||||
|
||||
child_count: DefaultDict[UOp, int] = defaultdict(int)
|
||||
for ru in uops:
|
||||
@@ -120,8 +120,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
|
||||
for u in uops:
|
||||
uop,dtype,vin,args,_ = u
|
||||
if uop == UOps.LOOP:
|
||||
r[u] = ssa('ridx')
|
||||
kk(lang.render_for(r[u], r[vin[0]], r[vin[1]]))
|
||||
kk(lang.render_for(ssa(u,'ridx'), r[vin[0]], r[vin[1]]))
|
||||
depth += 1
|
||||
elif uop == UOps.BARRIER:
|
||||
kk(lang.barrier)
|
||||
@@ -158,12 +157,10 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
|
||||
if child_count[u] <= 1 or dtypes.is_int(dtype): # fix index rendering issue
|
||||
r[u] = val
|
||||
else:
|
||||
r[u] = ssa('alu')
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};")
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'alu')} = {val};")
|
||||
elif uop == UOps.DEFINE_ACC:
|
||||
assert dtype is not None
|
||||
r[u] = ssa('acc')
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {lang.render_const(args, dtype)};")
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
|
||||
elif uop == UOps.SPECIAL:
|
||||
xid = lang.gid if args[1].startswith("g") else lang.lid
|
||||
kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */")
|
||||
@@ -175,8 +172,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
|
||||
assert dtype is not None
|
||||
val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)
|
||||
if len(vin) > 2: val = lang.render_conditional(r[vin[2]], val, r[vin[3]])
|
||||
r[u] = ssa('val')
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};")
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'val')} = {val};")
|
||||
elif uop == UOps.STORE:
|
||||
if len(vin) == 2:
|
||||
kk(f"{r[vin[0]]} = {r[vin[1]]};")
|
||||
@@ -185,11 +181,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
|
||||
kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
|
||||
elif uop == UOps.CAST and dtype is not None and dtype.sz > 1:
|
||||
val = lang.render_cast([r[x] for x in vin], dtype)
|
||||
if child_count[u] <= 1:
|
||||
r[u] = val
|
||||
else:
|
||||
r[u] = ssa('cast')
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};")
|
||||
if child_count[u] <= 1: r[u] = val
|
||||
else: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'cast')} = {val};")
|
||||
elif uop == UOps.DEFINE_LOCAL:
|
||||
if lang.external_local_bufs:
|
||||
prekernel.append(lang.render_local(args[0], args[1]))
|
||||
|
||||
Reference in New Issue
Block a user