diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 5c1df70011..6258b885dc 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -36,6 +36,16 @@ jobs:
         #IMAGE=2 PYTHON=1 python3 test/test_ops.py TestOps.test_simple_conv2d
     - name: Test emulated METAL tensor cores
       run: DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_big_gemm
+    - name: Test emulated HIP tensor cores
+      run: |
+        PYTHONPATH=. DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
+        PYTHONPATH=. DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
+        PYTHONPATH=. DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
+        PYTHONPATH=. DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
+    - name: Full test tensor cores
+      run: |
+        DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
+        DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
 
   linter:
     name: Linters
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index fc6cfc0744..39be6b646f 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -106,7 +106,7 @@ class TestLinearizer(unittest.TestCase):
   def test_tensor_cores(self):
     if not Device[Device.DEFAULT].compiler.linearizer_opts.has_tensor_cores:
       self.skipTest("device doesn't have tensor cores")
-    for tc in tensor_cores[Device.DEFAULT]:
+    for tc in tensor_cores[Device[Device.DEFAULT].compiler.linearizer_opts.device]:
       a, b = Tensor.rand(tc.dims[1], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[0], dtype=tc.dtype_in)
       np_a, np_b = a.numpy(), b.numpy()
       r = a.matmul(b, acc_dtype=tc.dtype_out)
@@ -536,7 +536,7 @@ class TestLinearizerOpts(unittest.TestCase):
     N = 128
     Tensor.manual_seed(1552)
-    for tc in tensor_cores[Device.DEFAULT]:
+    for tc in tensor_cores[Device[Device.DEFAULT].compiler.linearizer_opts.device]:
       a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in)
       r = a.matmul(b, acc_dtype=tc.dtype_out)
       (atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index dfea4f916a..fd98ed62a7 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -4,7 +4,7 @@
 from typing import Tuple, List, Optional, Any, Dict
 import pickle, base64, itertools, time, math
 from tinygrad.dtype import DType, dtypes, ImageDType
-from tinygrad.helpers import all_same, getenv
+from tinygrad.helpers import all_same, getenv, flatten
 from tinygrad.device import Compiled, Allocator, Compiler
 from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
@@ -148,27 +148,36 @@ class PythonProgram:
           ul[i] = inp[0][arg]
         elif uop is UOps.WMMA:
           # here are the models for the WMMA instruction on the different hardware
-          if arg == '__metal_wmma':
-            order = [0, 32, 1, 33, 8, 40, 9, 41,
-                     2, 34, 3, 35, 10, 42, 11, 43,
-                     4, 36, 5, 37, 12, 44, 13, 45,
-                     6, 38, 7, 39, 14, 46, 15, 47,
-                     16, 48, 17, 49, 24, 56, 25, 57,
-                     18, 50, 19, 51, 26, 58, 27, 59,
-                     20, 52, 21, 53, 28, 60, 29, 61,
-                     22, 54, 23, 55, 30, 62, 31, 63]
-            def unswizzle(goff, x): return [x[0][goff+idx] if idx < 32 else
-                                            x[1][goff+idx-32] for idx in order]
-            out = inp[2][0][:], inp[2][1][:]
-            for goff in range(0, warp_size, 32):
-              m1,m2 = unswizzle(goff, inp[0]), unswizzle(goff, inp[1])
-              for _i in range(8):
-                for _j in range(8):
-                  oidx = order[_i*8 + _j]
-                  nval = sum(m1[_i*8+_k] * m2[_k*8+_j] for _k in range(8))
-                  if oidx < 32: out[0][goff+oidx] += nval
-                  else: out[1][goff+oidx-32] += nval
-            ul[i] = out
+          def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
+            assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread"
+            assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread"
+            assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread"
+            assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A in WMMA"
+            assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B in WMMA"
+            assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C in WMMA"
+            assert warp_size > 0 and warp_size % WARP_THREADS == 0, f"must have multiples of {WARP_THREADS} warp threads"
+            out = [inp[2][elem_idx][:] for elem_idx in range(NUM_C)]
+            for goff in range(0, warp_size, WARP_THREADS):
+              for lane_id in range(WARP_THREADS):
+                for elem_idx in range(NUM_C): # calculate new muls and add to acc
+                  (c_i, c_j) = c_map(lane_id, elem_idx)
+                  out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
+            return out
+
+          if arg.startswith('__metal_wmma'):
+            def a_b_elem(x, i, j, goff): # A (2 elements on 32 threads): row major
+              return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
+            def c_map(lane, elem): # (i, j), C, D (2 elements on 32 threads): row major same as A/B
+              return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
+            ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
+          elif arg == '__builtin_amdgcn_wmma_f32_16x16x16_f16_w32' or arg == '__hip_wmma_f16_f16':
+            def a_elem(x, i, j, goff): # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
+              assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
+              return x[i][goff+j]
+            def b_elem(x, i, j, goff): # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
+              return a_elem(x, j, i, goff)
+            def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
+            ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
           else: raise Exception(f"unimplemented tensor core {arg}")
         elif uop is UOps.ALU:
@@ -180,7 +189,8 @@
     return time.perf_counter() - st
 
 class PythonCompiler(Compiler):
-  linearizer_opts = LinearizerOptions("METAL", has_tensor_cores=True) if getenv("EMULATE_METAL") else LinearizerOptions()
+  linearizer_opts = LinearizerOptions("METAL", has_tensor_cores=True) if getenv("EMULATE_METAL") else \
+    (LinearizerOptions("HIP", has_tensor_cores=True) if getenv("EMULATE_HIP") else LinearizerOptions())
   def render(self, name:str, uops:List[UOp]) -> str:
     lops = [(u.uop, u.dtype, [uops.index(v) for v in u.vin], u.arg) for u in uops]
     return base64.b64encode(pickle.dumps(lops)).decode()
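Below is a minimal standalone sketch (not part of the patch) of the lane/element bookkeeping that wmma_helper does for the HIP/RDNA3 case: one 32-wide warp, 16 A and 16 B elements per lane with lanes 16-31 mirroring lanes 0-15, and 8 accumulator elements per lane. The a_elem/b_elem/c_map accessors are copied from the '__builtin_amdgcn_wmma_f32_16x16x16_f16_w32' branch above; the packing helpers (A_pack, B_pack, C_pack) and the choice of which operand supplies rows vs. columns are assumptions made for this sketch only, checked against a plain triple-loop 16x16x16 matmul.

# standalone sketch, not tinygrad code: emulate one 32-lane warp doing a 16x16x16 WMMA
# with the HIP-style accessors, then compare against a reference matmul
import random

WARP_THREADS, K, NUM_A, NUM_B, NUM_C = 32, 16, 16, 16, 8
warp_size = WARP_THREADS  # a single warp for this sketch

def wmma_helper(inp, a_elem, b_elem, c_map):
  # same loop structure as the emulator (inp passed explicitly here): copy the
  # accumulator, then add the K products for each output element of each lane
  out = [inp[2][elem_idx][:] for elem_idx in range(NUM_C)]
  for goff in range(0, warp_size, WARP_THREADS):
    for lane_id in range(WARP_THREADS):
      for elem_idx in range(NUM_C):
        c_i, c_j = c_map(lane_id, elem_idx)
        out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], k, c_j, goff) * b_elem(inp[1], c_i, k, goff) for k in range(K))
  return out

# accessors copied from the HIP branch: 16 A/B elements per lane, lanes 16-31 mirror lanes 0-15
def a_elem(x, i, j, goff):
  assert x[i][goff+j] == x[i][goff+j+16], "lanes 16-31 must mirror lanes 0-15"
  return x[i][goff+j]
def b_elem(x, i, j, goff): return a_elem(x, j, i, goff)
def c_map(lane, elem): return (lane%16, lane//16+elem*2)  # 8 C/D elements per lane, row major

# two random 16x16 matrices and their plain-Python product X @ Y
X = [[random.random() for _ in range(16)] for _ in range(16)]
Y = [[random.random() for _ in range(16)] for _ in range(16)]
ref = [[sum(X[i][k]*Y[k][j] for k in range(16)) for j in range(16)] for i in range(16)]

# assumed packing: a_elem(A_pack, k, j) returns Y[k][j] and b_elem(B_pack, i, k) returns X[i][k],
# with every value duplicated across the upper 16 lanes as the accessors expect
A_pack = [[Y[k][lane%16] for lane in range(warp_size)] for k in range(NUM_A)]
B_pack = [[X[lane%16][k] for lane in range(warp_size)] for k in range(NUM_B)]
C_pack = [[0.0]*warp_size for _ in range(NUM_C)]

out = wmma_helper((A_pack, B_pack, C_pack), a_elem, b_elem, c_map)

# every (lane, elem) output should land on the matching entry of the reference product
for lane in range(warp_size):
  for elem in range(NUM_C):
    i, j = c_map(lane, elem)
    assert abs(out[elem][lane] - ref[i][j]) < 1e-9
print("HIP-style WMMA lane mapping reproduces the 16x16x16 reference matmul")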