This commit is contained in:
Mesozoic Egg
2024-11-25 16:19:21 +08:00
parent 78d2b9fb52
commit 3aa9f77517

View File

@@ -180,33 +180,33 @@ class PTXRenderer(Renderer):
if u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType):
r[u] = r[u.src[0]]
continue
# r[u] = ssa('cast', u, self.types[u.dtype])
# elif u.op is Ops.ENDRANGE: r[u] = ssa("pred", u, dtype="pred")
# elif u.op is Ops.RANGE: r[u] = ssa("ridx", u)
# elif u.op in GroupOp.ALU: r[u] = ssa("alu", u)
# elif u.op is Ops.DEFINE_ACC:
# r[u] = [ssa('acc', u, dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("acc", u)
r[u] = ssa('cast', u, self.types[u.dtype])
elif u.op is Ops.ENDRANGE: r[u] = ssa("pred", u, dtype="pred")
elif u.op is Ops.RANGE: r[u] = ssa("ridx", u)
elif u.op in GroupOp.ALU: r[u] = ssa("alu", u)
elif u.op is Ops.DEFINE_ACC:
r[u] = [ssa('acc', u, dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("acc", u)
elif u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
elif u.op is Ops.DEFINE_VAR:
bufs.append((u.arg[0], u.dtype))
# r[u] = ssa("dat", u, self.types[u.dtype])
# elif u.op is Ops.CONST: r[u] = ssa("const", u, dtype=self.types[u.dtype])
r[u] = ssa("dat", u, self.types[u.dtype])
elif u.op is Ops.CONST: r[u] = ssa("const", u, dtype=self.types[u.dtype])
elif u.op is Ops.LOAD:
assert u.src[0].dtype == dtypes.int64, "load isn't int64"
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
# elif u.op is Ops.DEFINE_LOCAL: r[u] = ssa('local', u, self.types[dtypes.ulong])
elif u.op is Ops.DEFINE_LOCAL: r[u] = ssa('local', u, self.types[dtypes.ulong])
elif u.op is Ops.DEFINE_GLOBAL:
bufs.append((f"data{u.arg}", u.dtype))
# r[u] = ssa('dat', u, self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype])
r[u] = ssa('dat', u, self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype])
elif u.op is Ops.WMMA:
self.wmma_r = [ssa("wmma", dtype="b32") for vv in u.src[:2] for i in range(0, len(r[vv]), 2)]
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
prefix, *_dtype = {Ops.CAST: ("cast",), Ops.BITCAST: ("cast",), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx",),
Ops.DEFINE_ACC: ("acc",), Ops.DEFINE_VAR: ("dat",), Ops.CONST: ("const",), Ops.DEFINE_LOCAL: ("local", "u64"),
Ops.DEFINE_GLOBAL: ("dat", lambda u: self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype])
}.get(u.op, (None,))
if prefix is None: prefix = "alu" if u.op in GroupOp.ALU else None
if prefix: r[u] = ssa(prefix, u, _dtype[0](u) if _dtype and callable(_dtype[0]) else _dtype[0] if _dtype else None)
# prefix, *_dtype = {Ops.CAST: ("cast",), Ops.BITCAST: ("cast",), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx",),
# Ops.DEFINE_ACC: ("acc",), Ops.DEFINE_VAR: ("dat",), Ops.CONST: ("const",), Ops.DEFINE_LOCAL: ("local", "u64"),
# Ops.DEFINE_GLOBAL: ("dat", lambda u: self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype])
# }.get(u.op, (None,))
# if prefix is None: prefix = "alu" if u.op in GroupOp.ALU else None
# if prefix: r[u] = ssa(prefix, u, _dtype[0](u) if _dtype and callable(_dtype[0]) else _dtype[0] if _dtype else None)
if (l:=cast(Union[str, List[str]], string_rewrite.rewrite(u, ctx=self))) is None:
raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.u.src]}")