diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index 2d2c616189..8cc06185c9 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -172,6 +172,10 @@ class PTXRenderer(Renderer): kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};") r[u] = "%" + args[1] kernel = [f".reg .u32 %{args[1]};"] + kernel + elif uop is UOps.DEFINE_VAR: + bufs.append((args.expr, dtype)) + r[u] = f"%{args.expr}" + kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param")) elif uop is UOps.CONST: if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)] else: r[u] = const(args, dtype, mov=True) @@ -205,10 +209,6 @@ class PTXRenderer(Renderer): # TODO: we should sum these, and fetch 0xC000 from somewhere assert args[1]*dtype.itemsize <= 0xC000, "too large local" kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype)) - elif uop is UOps.DEFINE_VAR: - bufs.append((args.expr, dtype)) - r[u] = f"%{args.expr}" - kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param")) elif uop is UOps.DEFINE_GLOBAL: bufs.append((nm:=f"data{args[0]}", dtype)) r[u] = f"%{nm}" diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index f8ca7f501d..3eb2969014 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -141,6 +141,11 @@ class CStyleLanguage(Renderer): elif uop is UOps.SPECIAL: kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */") r[u] = args[1] + elif uop is UOps.DEFINE_VAR: + assert args.expr not in seen_vars, f"duplicate variable {args.expr}" + seen_vars.add(args.expr) + bufs.append((args.expr, (dtype,False))) + r[u] = args.expr elif uop is UOps.LOAD: val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL) # NOTE: this relies on the load not happening if it's in the unselected branch @@ -162,11 +167,6 @@ class CStyleLanguage(Renderer): elif uop is UOps.DEFINE_LOCAL: kk(self.render_local(args[0], dtype, args[1])) r[u] = args[0] - elif uop is UOps.DEFINE_VAR: - assert args.expr not in seen_vars, f"duplicate variable {args.expr}" - seen_vars.add(args.expr) - bufs.append((args.expr, (dtype,False))) - r[u] = args.expr elif uop is UOps.DEFINE_GLOBAL: bufs.append((nm:=f"data{args[0]}", (dtype,args[1]))) r[u] = nm diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 3526a72058..f649bb44b6 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -84,10 +84,8 @@ class PythonProgram: elif uop is UOps.DEFINE_VAR: ul[i] = [pvals.pop(0)] * warp_size elif uop is UOps.SPECIAL: - if arg[1][0] == 'g': - ul[i] = [idxs[2-arg[0]]] * warp_size - elif arg[1][0] == 'l': - ul[i] = [x[2-arg[0]] for x in warp] + if arg[1][0] == 'g': ul[i] = [idxs[2-arg[0]]] * warp_size + elif arg[1][0] == 'l': ul[i] = [x[2-arg[0]] for x in warp] elif uop is UOps.CONST: ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size elif uop is UOps.DEFINE_ACC: