diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 63f98904c9..5117f24899 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -207,9 +207,9 @@ class PTXRenderer(Renderer): elif u.op is Ops.DEFINE_GLOBAL: bufs.append((f"data{u.arg}", u.dtype)) elif u.op is Ops.WMMA: # registers for packing/unpacking input and acc - self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.dtype.scalar().itemsize)], - [ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.dtype.scalar().itemsize)], - [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.src[0].dtype.scalar().itemsize)]] + self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.src[0].dtype.scalar().itemsize)], + [ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.src[0].dtype.scalar().itemsize)], + [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]] r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx", None), Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL:("local",self.types[dtypes.ulong]),