mirror of https://github.com/tinygrad/tinygrad.git
cleanup wmma (#12927)

* cleanup wmma
* fix test_ops failures on android
@@ -28,7 +28,7 @@ repos:
         pass_filenames: false
       - id: tests
         name: subset of tests
-        entry: env PYTHONPATH="." python3 -m pytest -n=8 test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_assign.py
+        entry: env OMP_NUM_THREADS=1 PYTHONPATH="." python3 -m pytest -n=8 test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_assign.py
         language: system
         always_run: true
         pass_filenames: false
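The added OMP_NUM_THREADS=1 presumably keeps OpenMP-backed libraries single-threaded inside each of the eight pytest-xdist workers, so the hook does not oversubscribe the CPU; the pre-commit entry is otherwise unchanged.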
@@ -3,7 +3,7 @@
 # works to test the tensor cores, and all the uops in general
 # this is the (living) definition of uops
 from typing import Any, TYPE_CHECKING, cast
-import pickle, base64, itertools, time, struct, sys
+import pickle, base64, itertools, time, struct, sys, functools
 from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16, float_to_fp8, fp8_to_float
 from tinygrad.helpers import all_same, getenv, flatten, get_single_element, EMULATE
 from tinygrad.device import Compiled, Compiler, Allocator
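The only change here is adding functools to the imports; it backs the functools.partial binding introduced in the WMMA branch further down.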
@@ -36,6 +36,20 @@ def _store(m, i, v, dtype: DType):
   if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
   m[i] = to_storage_scalar(v, dtype)
 
+# here are the models for the WMMA instruction on the different hardware
+def generic_wmma_helper(inp, warp_size, WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
+  for cc, tinp, num in zip(("A", "B", "C"), inp, (NUM_A, NUM_B, NUM_C)):
+    assert len(tinp) == num, f"{cc} must have {num} elements per thread, it has {len(tinp)}"
+    assert len(flatten(tinp)) == num * warp_size, f"WMMA must have {num * warp_size} total elements for {cc} in WMMA"
+  assert warp_size > 0 and warp_size % WARP_THREADS == 0, f"must have multiples of {WARP_THREADS} warp threads"
+  out = [inp[2][elem_idx][:] for elem_idx in range(NUM_C)]
+  for goff in range(0, warp_size, WARP_THREADS):
+    for lane_id in range(WARP_THREADS):
+      for elem_idx in range(NUM_C): # calculate new muls and add to acc
+        (c_i, c_j) = c_map(lane_id, elem_idx)
+        out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
+  return out
+
 class PythonProgram:
   def __init__(self, name:str, lib:bytes):
     self.uops: list[tuple[Ops, DType|None, list[int], Any]] = pickle.loads(lib)
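A minimal sketch (not part of the diff) of driving the new module-level helper the way the CPU branch does, with WARP_THREADS=1, K=1, 16 A elements, 16 B elements and 256 accumulators per thread; the tinygrad.runtime.ops_python import path is an assumption about where this file lives:

    import functools
    from tinygrad.runtime.ops_python import generic_wmma_helper  # assumed module path

    warp_size = 1                                    # a single lane, like the CPU path
    A = [[float(j)] for j in range(16)]              # 16 A elements per thread, 1 lane each
    B = [[float(i)] for i in range(16)]              # 16 B elements per thread, 1 lane each
    C = [[0.0] for _ in range(256)]                  # 256 accumulators, 1 lane each

    def elem(x, col, row, _): return x[col+row][0]   # k is always 0, as in the CPU branch
    def c_map(lane, e): return (e % 16, e // 16)     # accumulator e -> (c_i, c_j)

    # bind the shared per-dispatch state once, the way the new WMMA branch does
    wmma_helper = functools.partial(generic_wmma_helper, (A, B, C), warp_size)
    out = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)

    # with K=1 each call is a rank-1 update: out[c_j*16 + c_i][0] == A[c_j][0] * B[c_i][0]
    assert out[3*16 + 5][0] == A[3][0] * B[5][0]     # 3.0 * 5.0 == 15.0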
@@ -125,23 +139,10 @@ class PythonProgram:
           ul[i] = load(inp, 0, dtype)
         elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)]
         elif uop is Ops.WMMA:
-          # here are the models for the WMMA instruction on the different hardware
-          def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
-            for cc, tinp, num in zip(("A", "B", "C"), inp, (NUM_A, NUM_B, NUM_C)):
-              assert len(tinp) == num, f"{cc} must have {num} elements per thread, it has {len(tinp)}"
-              assert len(flatten(tinp)) == num * warp_size, f"WMMA must have {num * warp_size} total elements for {cc} in WMMA"
-            assert warp_size > 0 and warp_size % WARP_THREADS == 0, f"must have multiples of {WARP_THREADS} warp threads"
-            out = [inp[2][elem_idx][:] for elem_idx in range(NUM_C)]
-            for goff in range(0, warp_size, WARP_THREADS):
-              for lane_id in range(WARP_THREADS):
-                for elem_idx in range(NUM_C): # calculate new muls and add to acc
-                  (c_i, c_j) = c_map(lane_id, elem_idx)
-                  out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
-            return out
-
           first_src_dtype = self.uops[idp[0]][1]
           assert isinstance(first_src_dtype, DType) # mypy
           dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5]
+          wmma_helper = functools.partial(generic_wmma_helper, inp, warp_size)
           # TODO: refactor these to a shared TensorCoreLayout in kernel.py
           if device == "METAL":
             # A (2 elements on 32 threads): row major
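The only functional change in this hunk is that the nested closure is replaced by a functools.partial that pre-binds the per-dispatch state (inp and warp_size), so every device branch keeps its existing call signature; for instance the CPU call

    wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)

now expands to

    generic_wmma_helper(inp, warp_size, 1, 1, 16, 16, 256, elem, elem, c_map)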
@@ -203,7 +204,7 @@ class PythonProgram:
             ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
           elif device == "CPU":
             def elem(x, col, row, _): return x[col+row][0] # k is always 0
-            def c_map(_, elem): return (elem%16, elem//16)
+            def c_map(lane, elem): return (elem%16, elem//16)
             ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
           else: raise NotImplementedError(f"unimplemented tensor core {arg}")
         elif uop in GroupOp.ALU:
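In the CPU branch only the name of the unused first parameter of c_map changes (_ becomes lane); generic_wmma_helper still passes (lane_id, elem_idx) positionally, so behaviour is identical. With WARP_THREADS=1 and K=1 each CPU WMMA is a 16x16 rank-1 update: accumulator e maps to (c_i, c_j) = (e%16, e//16), e.g. c_map(0, 37) == (5, 2), so C element 37 accumulates A element 2 times B element 5.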