move lines around in ops_python wmma (#3911)

Author: chenyu
Date: 2024-03-24 17:14:26 -04:00
Committed by: GitHub
Parent: e27129a798
Commit: 2e39f57594

@@ -147,23 +147,27 @@ class PythonProgram:
             return out
           if arg.startswith('__metal_wmma'):
-            def a_b_elem(x, i, j, goff): # A (2 elements on 32 threads): row major
-              return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
-            def c_map(lane, elem): # (i, j), C, D (2 elements on 32 threads): row major same as A/B
-              return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
+            # A (2 elements on 32 threads): row major
+            def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
+            # (i, j), C, D (2 elements on 32 threads): row major same as A/B
+            def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
             ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
           elif arg == '__builtin_amdgcn_wmma_f32_16x16x16_f16_w32' or arg == '__hip_wmma_f16_f16':
-            def a_elem(x, i, j, goff): # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
+            # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
+            def a_elem(x, i, j, goff):
               assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
               return x[i][goff+j]
-            def b_elem(x, i, j, goff): # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
-              return a_elem(x, j, i, goff)
+            # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
+            def b_elem(x, i, j, goff): return a_elem(x, j, i, goff)
             def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
             ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
           elif arg == '__cuda_mma_m16n8k16_f16_f32':
-            def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4] # A (8 elements on 32 threads)
-            def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4] # B (4 elements on 32 threads)
-            def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8) # (i, j), C, D (4 elements on 32 threads)
+            # A (8 elements on 32 threads)
+            def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
+            # B (4 elements on 32 threads)
+            def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4]
+            # (i, j), C, D (4 elements on 32 threads)
+            def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)
             ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
           else: raise NotImplementedError(f"unimplemented tensor core {arg}")
         elif uop is UOps.ALU:
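
Note: the comments moved in this diff describe how each tensor-core variant maps a (lane, elem) pair to an (i, j) coordinate of the C/D accumulator fragment. Below is a minimal standalone sanity check, not part of the commit, that copies the three new c_map definitions verbatim and asserts that each one covers its full output tile exactly once across 32 lanes. The tile extents used here (8x8 for Metal, 16x16 for AMD/HIP, an 8x16 coordinate range for the CUDA m16n8k16 accumulator) are inferred from the mappings themselves and are assumptions, not values read from ops_python.py.

# Standalone sketch: check that each c_map is one-to-one onto its assumed output tile.
def metal_c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
def amd_c_map(lane, elem): return (lane%16, lane//16+elem*2)
def cuda_c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)

for name, c_map, num_c, (di, dj) in [
  ("metal", metal_c_map, 2, (8, 8)),          # 2 C/D elements per thread, assumed 8x8 tile
  ("amd/hip", amd_c_map, 8, (16, 16)),        # 8 C/D elements per thread, assumed 16x16 tile
  ("cuda m16n8k16", cuda_c_map, 4, (8, 16)),  # 4 C/D elements per thread, assumed 8x16 coordinate range
]:
  seen = [c_map(lane, elem) for lane in range(32) for elem in range(num_c)]
  assert len(set(seen)) == len(seen) == di*dj, f"{name}: mapping is not one-to-one"
  assert set(seen) == {(i, j) for i in range(di) for j in range(dj)}, f"{name}: mapping does not cover the tile"
  print(f"{name}: 32 lanes x {num_c} elems -> {di}x{dj} tile, each coordinate produced exactly once")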