diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index 17b711a763..7ff41c4e4e 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -227,9 +227,9 @@ class PTXRenderer(Renderer): for i in range(0, len(r[vv]), 2): wmma.append(ssa("wmma", dtype="b32")) kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};') - r[u] = r[vin[2]] + r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)] kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\ - {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[u])}}};') + {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[vin[2]])}}};') else: raise NotImplementedError(f"no code for {uop}") return self.render_kernel(kernel, name, bufs, c.items())