diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 953c19de9c..c21f857d40 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -133,7 +133,7 @@ string_rewrite = PatternMatcher([ (UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"), (UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"), (UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))), - (UPat(Ops.BARRIER, name="x"), lambda ctx, x: ctx.barrier), + (UPat(Ops.BARRIER), lambda ctx: ctx.barrier), (UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.arg[0]}+0];"), ]) @@ -180,7 +180,7 @@ class PTXRenderer(Renderer): self.uops = uops def ssa(prefix:str, u:UOp|None=None, dtype:str|None=None) -> str: - nonlocal c, r + nonlocal c prefix += f"_{dtype if dtype is not None else self.types[unwrap(u).dtype.base]}_" c[prefix] += 1 return f"%{prefix}{c[prefix]-1}" @@ -230,7 +230,7 @@ class PTXRenderer(Renderer): [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]] r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.END: ("pred", "pred"), Ops.RANGE: ("ridx", None), - Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local",self.types[dtypes.ulong]), + Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local", self.types[dtypes.ulong]), Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None)) if prefix: r[u] = ssa(prefix, u, dtype)