reorder UOps.DEFINE_VAR in runtime [run_process_replay] (#5659)

Prep for rewriting SPECIAL using DEFINE_VAR.
This commit is contained in:
chenyu
2024-07-23 14:32:10 -04:00
committed by GitHub
parent 199b3bf02b
commit fdc72ba102
3 changed files with 11 additions and 13 deletions

View File

@@ -172,6 +172,10 @@ class PTXRenderer(Renderer):
kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
r[u] = "%" + args[1]
kernel = [f".reg .u32 %{args[1]};"] + kernel
elif uop is UOps.DEFINE_VAR:
bufs.append((args.expr, dtype))
r[u] = f"%{args.expr}"
kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
elif uop is UOps.CONST:
if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
else: r[u] = const(args, dtype, mov=True)
@@ -205,10 +209,6 @@ class PTXRenderer(Renderer):
# TODO: we should sum these, and fetch 0xC000 from somewhere
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
elif uop is UOps.DEFINE_VAR:
bufs.append((args.expr, dtype))
r[u] = f"%{args.expr}"
kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
elif uop is UOps.DEFINE_GLOBAL:
bufs.append((nm:=f"data{args[0]}", dtype))
r[u] = f"%{nm}"

View File

@@ -141,6 +141,11 @@ class CStyleLanguage(Renderer):
elif uop is UOps.SPECIAL:
kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
r[u] = args[1]
elif uop is UOps.DEFINE_VAR:
assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
seen_vars.add(args.expr)
bufs.append((args.expr, (dtype,False)))
r[u] = args.expr
elif uop is UOps.LOAD:
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
# NOTE: this relies on the load not happening if it's in the unselected branch
@@ -162,11 +167,6 @@ class CStyleLanguage(Renderer):
elif uop is UOps.DEFINE_LOCAL:
kk(self.render_local(args[0], dtype, args[1]))
r[u] = args[0]
elif uop is UOps.DEFINE_VAR:
assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
seen_vars.add(args.expr)
bufs.append((args.expr, (dtype,False)))
r[u] = args.expr
elif uop is UOps.DEFINE_GLOBAL:
bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
r[u] = nm

View File

@@ -84,10 +84,8 @@ class PythonProgram:
elif uop is UOps.DEFINE_VAR:
ul[i] = [pvals.pop(0)] * warp_size
elif uop is UOps.SPECIAL:
if arg[1][0] == 'g':
ul[i] = [idxs[2-arg[0]]] * warp_size
elif arg[1][0] == 'l':
ul[i] = [x[2-arg[0]] for x in warp]
if arg[1][0] == 'g': ul[i] = [idxs[2-arg[0]]] * warp_size
elif arg[1][0] == 'l': ul[i] = [x[2-arg[0]] for x in warp]
elif uop is UOps.CONST:
ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
elif uop is UOps.DEFINE_ACC: