no-op cleanups for ptx [pr] (#13938)

This commit is contained in:
chenyu
2025-12-31 17:28:39 -05:00
committed by GitHub
parent 2bb07d4824
commit 8bf7c9c1d2

View File

@@ -133,7 +133,7 @@ string_rewrite = PatternMatcher([
(UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),
(UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
(UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
(UPat(Ops.BARRIER, name="x"), lambda ctx, x: ctx.barrier),
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.arg[0]}+0];"),
])
@@ -180,7 +180,7 @@ class PTXRenderer(Renderer):
self.uops = uops
def ssa(prefix:str, u:UOp|None=None, dtype:str|None=None) -> str:
nonlocal c, r
nonlocal c
prefix += f"_{dtype if dtype is not None else self.types[unwrap(u).dtype.base]}_"
c[prefix] += 1
return f"%{prefix}{c[prefix]-1}"
@@ -230,7 +230,7 @@ class PTXRenderer(Renderer):
[ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]]
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.END: ("pred", "pred"), Ops.RANGE: ("ridx", None),
Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local",self.types[dtypes.ulong]),
Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local", self.types[dtypes.ulong]),
Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))
if prefix: r[u] = ssa(prefix, u, dtype)