mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 22:08:08 -05:00
reorder UOps.DEFINE_VAR in runtime [run_process_replay] (#5659)
Prep for rewriting SPECIAL using DEFINE_VAR.
This commit is contained in:
@@ -172,6 +172,10 @@ class PTXRenderer(Renderer):
|
||||
kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
|
||||
r[u] = "%" + args[1]
|
||||
kernel = [f".reg .u32 %{args[1]};"] + kernel
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
bufs.append((args.expr, dtype))
|
||||
r[u] = f"%{args.expr}"
|
||||
kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
|
||||
elif uop is UOps.CONST:
|
||||
if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
|
||||
else: r[u] = const(args, dtype, mov=True)
|
||||
@@ -205,10 +209,6 @@ class PTXRenderer(Renderer):
|
||||
# TODO: we should sum these, and fetch 0xC000 from somewhere
|
||||
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
|
||||
kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
bufs.append((args.expr, dtype))
|
||||
r[u] = f"%{args.expr}"
|
||||
kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
|
||||
elif uop is UOps.DEFINE_GLOBAL:
|
||||
bufs.append((nm:=f"data{args[0]}", dtype))
|
||||
r[u] = f"%{nm}"
|
||||
|
||||
@@ -141,6 +141,11 @@ class CStyleLanguage(Renderer):
|
||||
elif uop is UOps.SPECIAL:
|
||||
kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
|
||||
r[u] = args[1]
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
|
||||
seen_vars.add(args.expr)
|
||||
bufs.append((args.expr, (dtype,False)))
|
||||
r[u] = args.expr
|
||||
elif uop is UOps.LOAD:
|
||||
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
|
||||
# NOTE: this relies on the load not happening if it's in the unselected branch
|
||||
@@ -162,11 +167,6 @@ class CStyleLanguage(Renderer):
|
||||
elif uop is UOps.DEFINE_LOCAL:
|
||||
kk(self.render_local(args[0], dtype, args[1]))
|
||||
r[u] = args[0]
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
|
||||
seen_vars.add(args.expr)
|
||||
bufs.append((args.expr, (dtype,False)))
|
||||
r[u] = args.expr
|
||||
elif uop is UOps.DEFINE_GLOBAL:
|
||||
bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
|
||||
r[u] = nm
|
||||
|
||||
@@ -84,10 +84,8 @@ class PythonProgram:
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
ul[i] = [pvals.pop(0)] * warp_size
|
||||
elif uop is UOps.SPECIAL:
|
||||
if arg[1][0] == 'g':
|
||||
ul[i] = [idxs[2-arg[0]]] * warp_size
|
||||
elif arg[1][0] == 'l':
|
||||
ul[i] = [x[2-arg[0]] for x in warp]
|
||||
if arg[1][0] == 'g': ul[i] = [idxs[2-arg[0]]] * warp_size
|
||||
elif arg[1][0] == 'l': ul[i] = [x[2-arg[0]] for x in warp]
|
||||
elif uop is UOps.CONST:
|
||||
ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
|
||||
Reference in New Issue
Block a user