mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
Fix tensor cores in PTX (#4698)
This commit is contained in:
@@ -227,9 +227,9 @@ class PTXRenderer(Renderer):
|
||||
for i in range(0, len(r[vv]), 2):
|
||||
wmma.append(ssa("wmma", dtype="b32"))
|
||||
kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
|
||||
r[u] = r[vin[2]]
|
||||
r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
|
||||
{{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[u])}}};')
|
||||
{{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[vin[2]])}}};')
|
||||
else: raise NotImplementedError(f"no code for {uop}")
|
||||
|
||||
return self.render_kernel(kernel, name, bufs, c.items())
|
||||
|
||||
Reference in New Issue
Block a user