From 00bc2b738c833549cc9d3ea60753b5bed677a94d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20O=C5=BC=C3=B3g?= <58388001+SzymonOzog@users.noreply.github.com> Date: Thu, 23 May 2024 22:27:51 +0200 Subject: [PATCH] Fix tensor cores in PTX (#4698) --- tinygrad/renderer/assembly.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index 17b711a763..7ff41c4e4e 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -227,9 +227,9 @@ class PTXRenderer(Renderer): for i in range(0, len(r[vv]), 2): wmma.append(ssa("wmma", dtype="b32")) kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};') - r[u] = r[vin[2]] + r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)] kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\ - {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[u])}}};') + {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[vin[2]])}}};') else: raise NotImplementedError(f"no code for {uop}") return self.render_kernel(kernel, name, bufs, c.items())