ops_python: add HIP tensor core mock and refactor METAL (#3354)

* ops_python: add HIP tensor core mock and refactor METAL

* Add tests to CI

* add DEBUG=2 to full tests

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: Francis Lam
Date: 2024-02-09 03:46:06 -08:00
Committed by: GitHub
Parent: b385234961
Commit: ce21fdfb67

3 changed files with 45 additions and 25 deletions

@@ -4,7 +4,7 @@
 from typing import Tuple, List, Optional, Any, Dict
 import pickle, base64, itertools, time, math
 from tinygrad.dtype import DType, dtypes, ImageDType
-from tinygrad.helpers import all_same, getenv
+from tinygrad.helpers import all_same, getenv, flatten
 from tinygrad.device import Compiled, Allocator, Compiler
 from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
@@ -148,27 +148,36 @@ class PythonProgram:
           ul[i] = inp[0][arg]
         elif uop is UOps.WMMA:
           # here are the models for the WMMA instruction on the different hardware
-          if arg == '__metal_wmma<float2,simdgroup_float8x8,float2>':
-            order = [0, 32, 1, 33, 8, 40, 9, 41,
-                     2, 34, 3, 35, 10, 42, 11, 43,
-                     4, 36, 5, 37, 12, 44, 13, 45,
-                     6, 38, 7, 39, 14, 46, 15, 47,
-                     16, 48, 17, 49, 24, 56, 25, 57,
-                     18, 50, 19, 51, 26, 58, 27, 59,
-                     20, 52, 21, 53, 28, 60, 29, 61,
-                     22, 54, 23, 55, 30, 62, 31, 63]
-            def unswizzle(goff, x): return [x[0][goff+idx] if idx < 32 else
-                                            x[1][goff+idx-32] for idx in order]
-            out = inp[2][0][:], inp[2][1][:]
-            for goff in range(0, warp_size, 32):
-              m1,m2 = unswizzle(goff, inp[0]), unswizzle(goff, inp[1])
-              for _i in range(8):
-                for _j in range(8):
-                  oidx = order[_i*8 + _j]
-                  nval = sum(m1[_i*8+_k] * m2[_k*8+_j] for _k in range(8))
-                  if oidx < 32: out[0][goff+oidx] += nval
-                  else: out[1][goff+oidx-32] += nval
-            ul[i] = out
+          def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
+            assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread"
+            assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread"
+            assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread"
+            assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A"
+            assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B"
+            assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C"
+            assert warp_size > 0 and warp_size % WARP_THREADS == 0, f"warp_size must be a multiple of {WARP_THREADS}"
+            out = [inp[2][elem_idx][:] for elem_idx in range(NUM_C)]
+            for goff in range(0, warp_size, WARP_THREADS):
+              for lane_id in range(WARP_THREADS):
+                for elem_idx in range(NUM_C): # calculate new muls and add to acc
+                  (c_i, c_j) = c_map(lane_id, elem_idx)
+                  out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
+            return out
+          if arg.startswith('__metal_wmma'):
+            def a_b_elem(x, i, j, goff): # A (2 elements on 32 threads): row major
+              return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
+            def c_map(lane, elem): # (i, j), C, D (2 elements on 32 threads): row major, same as A/B
+              return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
+            ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
+          elif arg == '__builtin_amdgcn_wmma_f32_16x16x16_f16_w32' or arg == '__hip_wmma_f16_f16':
+            def a_elem(x, i, j, goff): # A (16 elements on 32 threads): col major, lanes 16-31 duplicate lanes 0-15
+              assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
+              return x[i][goff+j]
+            def b_elem(x, i, j, goff): # B (16 elements on 32 threads): row major, lanes 16-31 duplicate lanes 0-15
+              return a_elem(x, j, i, goff)
+            def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
+            ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
           else:
             raise Exception(f"unimplemented tensor core {arg}")
         elif uop is UOps.ALU:
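
The layout comments in the new branches are compact, so here is a minimal standalone sketch (not part of the diff; the function names are invented for illustration) that checks the properties they claim: the METAL A/B fragment layout touches every (element, lane) slot exactly once, and both c_map formulas cover their output tiles exactly once.

def metal_ab_slot(i, j):
  # which (element, lane) slot holds A/B matrix entry (i, j) in the METAL model;
  # same index formula as a_b_elem above, with goff=0
  return i % 2, (i // 2) % 2 + (j % 4) * 2 + (i // 4) * 8 + (j // 4) * 16

def metal_c_map(lane, elem):
  # (i, j) of the C entry held by (lane, elem); same formula as the METAL c_map
  return elem + (lane % 2) * 2 + ((lane // 8) % 2) * 4, (lane // 2) % 4 + (lane // 16) * 4

def hip_c_map(lane, elem):
  # same formula as the HIP c_map: 8 C entries per lane over a 16x16 tile
  return lane % 16, lane // 16 + elem * 2

# every (element, lane) slot of the 8x8 METAL A/B fragment is used exactly once
assert len({metal_ab_slot(i, j) for i in range(8) for j in range(8)}) == 64

# the METAL C map covers the 8x8 tile exactly once across 32 lanes x 2 elements
assert {metal_c_map(l, e) for l in range(32) for e in range(2)} == \
       {(i, j) for i in range(8) for j in range(8)}

# the HIP C map covers the 16x16 tile exactly once across 32 lanes x 8 elements
assert {hip_c_map(l, e) for l in range(32) for e in range(8)} == \
       {(i, j) for i in range(16) for j in range(16)}

All three asserts pass, which is exactly the property wmma_helper relies on when it accumulates one partial sum per (lane, element) slot through c_map.
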
@@ -180,7 +189,8 @@ class PythonProgram:
     return time.perf_counter() - st

 class PythonCompiler(Compiler):
-  linearizer_opts = LinearizerOptions("METAL", has_tensor_cores=True) if getenv("EMULATE_METAL") else LinearizerOptions()
+  linearizer_opts = LinearizerOptions("METAL", has_tensor_cores=True) if getenv("EMULATE_METAL") else \
+    (LinearizerOptions("HIP", has_tensor_cores=True) if getenv("EMULATE_HIP") else LinearizerOptions())
   def render(self, name:str, uops:List[UOp]) -> str:
     lops = [(u.uop, u.dtype, [uops.index(v) for v in u.vin], u.arg) for u in uops]
     return base64.b64encode(pickle.dumps(lops)).decode()
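
For context on the unchanged render method: it flattens the uop graph into index-based (uop, dtype, src_indices, arg) tuples, pickles the list, and base64-encodes the result, so the interpreting side only has to invert those two steps. A minimal sketch of that inverse, assuming this decoder shape (the function name is invented; the real decoding lives in the Python device runtime):

import base64, pickle

def decode_python_program(src: str):
  # inverse of PythonCompiler.render: base64 -> pickle -> list of
  # (uop, dtype, source_indices, arg) tuples ready to interpret
  return pickle.loads(base64.b64decode(src))

With the new else branch in linearizer_opts, running the Python backend with EMULATE_HIP=1 selects LinearizerOptions("HIP", has_tensor_cores=True), so the HIP WMMA mock above can be exercised without AMD hardware; per the commit message, the CI additions presumably run the full tests under these variables with DEBUG=2.
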