diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index 898294e9f4..b9a2977a95 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -147,23 +147,27 @@ class PythonProgram:
             return out
 
           if arg.startswith('__metal_wmma'):
-            def a_b_elem(x, i, j, goff): # A (2 elements on 32 threads): row major
-              return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
-            def c_map(lane, elem): # (i, j), C, D (2 elements on 32 threads): row major same as A/B
-              return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
+            # A (2 elements on 32 threads): row major
+            def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
+            # (i, j), C, D (2 elements on 32 threads): row major same as A/B
+            def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
             ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
           elif arg == '__builtin_amdgcn_wmma_f32_16x16x16_f16_w32' or arg == '__hip_wmma_f16_f16':
-            def a_elem(x, i, j, goff): # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
+            # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
+            def a_elem(x, i, j, goff):
               assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
               return x[i][goff+j]
-            def b_elem(x, i, j, goff): # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
-              return a_elem(x, j, i, goff)
+            # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
+            def b_elem(x, i, j, goff): return a_elem(x, j, i, goff)
             def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
             ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
           elif arg == '__cuda_mma_m16n8k16_f16_f32':
-            def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4] # A (8 elements on 32 threads)
-            def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4] # B (4 elements on 32 threads)
-            def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8) # (i, j), C, D (4 elements on 32 threads)
+            # A (8 elements on 32 threads)
+            def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
+            # B (4 elements on 32 threads)
+            def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4]
+            # (i, j), C, D (4 elements on 32 threads)
+            def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)
             ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
           else: raise NotImplementedError(f"unimplemented tensor core {arg}")
         elif uop is UOps.ALU:
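
The per-lane index arithmetic in these mappings is easy to get wrong, so a quick standalone check can confirm that a `c_map` covers its output tile exactly once. Below is a minimal sketch (not part of the diff) that copies the CUDA `__cuda_mma_m16n8k16_f16_f32` `c_map` from the hunk above and verifies that the 32 lanes with 4 elements each land on 128 distinct positions tiling the full 8x16 output:

```python
# Copy of the CUDA m16n8k16 c_map from the diff above (illustrative check only).
def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)

# 32 lanes x 4 elements per lane should hit every (i, j) in the 8x16 tile exactly once.
positions = {c_map(lane, elem) for lane in range(32) for elem in range(4)}
assert len(positions) == 32 * 4, "some (lane, elem) pairs collide on the same output position"
assert positions == {(i, j) for i in range(8) for j in range(16)}, "positions do not tile the 8x16 output"
print("c_map covers the 8x16 output tile exactly once")
```

The same kind of check works for the Metal and AMD mappings by swapping in their `c_map` and tile shape.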