mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
debug pr
This commit is contained in:
@@ -180,33 +180,33 @@ class PTXRenderer(Renderer):
|
||||
if u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType):
|
||||
r[u] = r[u.src[0]]
|
||||
continue
|
||||
# r[u] = ssa('cast', u, self.types[u.dtype])
|
||||
# elif u.op is Ops.ENDRANGE: r[u] = ssa("pred", u, dtype="pred")
|
||||
# elif u.op is Ops.RANGE: r[u] = ssa("ridx", u)
|
||||
# elif u.op in GroupOp.ALU: r[u] = ssa("alu", u)
|
||||
# elif u.op is Ops.DEFINE_ACC:
|
||||
# r[u] = [ssa('acc', u, dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("acc", u)
|
||||
r[u] = ssa('cast', u, self.types[u.dtype])
|
||||
elif u.op is Ops.ENDRANGE: r[u] = ssa("pred", u, dtype="pred")
|
||||
elif u.op is Ops.RANGE: r[u] = ssa("ridx", u)
|
||||
elif u.op in GroupOp.ALU: r[u] = ssa("alu", u)
|
||||
elif u.op is Ops.DEFINE_ACC:
|
||||
r[u] = [ssa('acc', u, dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("acc", u)
|
||||
elif u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
|
||||
elif u.op is Ops.DEFINE_VAR:
|
||||
bufs.append((u.arg[0], u.dtype))
|
||||
# r[u] = ssa("dat", u, self.types[u.dtype])
|
||||
# elif u.op is Ops.CONST: r[u] = ssa("const", u, dtype=self.types[u.dtype])
|
||||
r[u] = ssa("dat", u, self.types[u.dtype])
|
||||
elif u.op is Ops.CONST: r[u] = ssa("const", u, dtype=self.types[u.dtype])
|
||||
elif u.op is Ops.LOAD:
|
||||
assert u.src[0].dtype == dtypes.int64, "load isn't int64"
|
||||
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
|
||||
# elif u.op is Ops.DEFINE_LOCAL: r[u] = ssa('local', u, self.types[dtypes.ulong])
|
||||
elif u.op is Ops.DEFINE_LOCAL: r[u] = ssa('local', u, self.types[dtypes.ulong])
|
||||
elif u.op is Ops.DEFINE_GLOBAL:
|
||||
bufs.append((f"data{u.arg}", u.dtype))
|
||||
# r[u] = ssa('dat', u, self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype])
|
||||
r[u] = ssa('dat', u, self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype])
|
||||
elif u.op is Ops.WMMA:
|
||||
self.wmma_r = [ssa("wmma", dtype="b32") for vv in u.src[:2] for i in range(0, len(r[vv]), 2)]
|
||||
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
|
||||
prefix, *_dtype = {Ops.CAST: ("cast",), Ops.BITCAST: ("cast",), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx",),
|
||||
Ops.DEFINE_ACC: ("acc",), Ops.DEFINE_VAR: ("dat",), Ops.CONST: ("const",), Ops.DEFINE_LOCAL: ("local", "u64"),
|
||||
Ops.DEFINE_GLOBAL: ("dat", lambda u: self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype])
|
||||
}.get(u.op, (None,))
|
||||
if prefix is None: prefix = "alu" if u.op in GroupOp.ALU else None
|
||||
if prefix: r[u] = ssa(prefix, u, _dtype[0](u) if _dtype and callable(_dtype[0]) else _dtype[0] if _dtype else None)
|
||||
# prefix, *_dtype = {Ops.CAST: ("cast",), Ops.BITCAST: ("cast",), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx",),
|
||||
# Ops.DEFINE_ACC: ("acc",), Ops.DEFINE_VAR: ("dat",), Ops.CONST: ("const",), Ops.DEFINE_LOCAL: ("local", "u64"),
|
||||
# Ops.DEFINE_GLOBAL: ("dat", lambda u: self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype])
|
||||
# }.get(u.op, (None,))
|
||||
# if prefix is None: prefix = "alu" if u.op in GroupOp.ALU else None
|
||||
# if prefix: r[u] = ssa(prefix, u, _dtype[0](u) if _dtype and callable(_dtype[0]) else _dtype[0] if _dtype else None)
|
||||
|
||||
if (l:=cast(Union[str, List[str]], string_rewrite.rewrite(u, ctx=self))) is None:
|
||||
raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.u.src]}")
|
||||
|
||||
Reference in New Issue
Block a user