diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d64a3fbbff..1bebabf62f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,7 +28,7 @@ repos: pass_filenames: false - id: tests name: subset of tests - entry: env PYTHONPATH="." python3 -m pytest -n=8 test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_assign.py + entry: env OMP_NUM_THREADS=1 PYTHONPATH="." python3 -m pytest -n=8 test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_assign.py language: system always_run: true - pass_filenames: false \ No newline at end of file + pass_filenames: false diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 0815596021..b8ed1654d2 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -3,7 +3,7 @@ # works to test the tensor cores, and all the uops in general # this is the (living) definition of uops from typing import Any, TYPE_CHECKING, cast -import pickle, base64, itertools, time, struct, sys +import pickle, base64, itertools, time, struct, sys, functools from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16, float_to_fp8, fp8_to_float from tinygrad.helpers import all_same, getenv, flatten, get_single_element, EMULATE from tinygrad.device import Compiled, Compiler, Allocator @@ -36,6 +36,20 @@ def _store(m, i, v, dtype: DType): if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}") m[i] = to_storage_scalar(v, dtype) +# here are the models for the WMMA instruction on the different hardware +def generic_wmma_helper(inp, warp_size, WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map): + for cc, tinp, num in zip(("A", "B", "C"), inp, (NUM_A, NUM_B, NUM_C)): + assert len(tinp) == num, f"{cc} must have {num} elements per thread, it has {len(tinp)}" + assert len(flatten(tinp)) == num * warp_size, f"WMMA must have {num * warp_size} total elements for {cc} in WMMA" + assert warp_size > 0 and warp_size % WARP_THREADS == 0, f"must have multiples of {WARP_THREADS} warp threads" + out = [inp[2][elem_idx][:] for elem_idx in range(NUM_C)] + for goff in range(0, warp_size, WARP_THREADS): + for lane_id in range(WARP_THREADS): + for elem_idx in range(NUM_C): # calculate new muls and add to acc + (c_i, c_j) = c_map(lane_id, elem_idx) + out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K)) + return out + class PythonProgram: def __init__(self, name:str, lib:bytes): self.uops: list[tuple[Ops, DType|None, list[int], Any]] = pickle.loads(lib) @@ -125,23 +139,10 @@ class PythonProgram: ul[i] = load(inp, 0, dtype) elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)] elif uop is Ops.WMMA: - # here are the models for the WMMA instruction on the different hardware - def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map): - for cc, tinp, num in zip(("A", "B", "C"), inp, (NUM_A, NUM_B, NUM_C)): - assert len(tinp) == num, f"{cc} must have {num} elements per thread, it has {len(tinp)}" - assert len(flatten(tinp)) == num * warp_size, f"WMMA must have {num * warp_size} total elements for {cc} in WMMA" - assert warp_size > 0 and warp_size % WARP_THREADS == 0, f"must have multiples of {WARP_THREADS} warp threads" - out = [inp[2][elem_idx][:] for elem_idx in range(NUM_C)] - for goff in range(0, warp_size, WARP_THREADS): - for lane_id in range(WARP_THREADS): - for elem_idx in range(NUM_C): # calculate new muls and add to acc - (c_i, c_j) = c_map(lane_id, elem_idx) - out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K)) - return out - first_src_dtype = self.uops[idp[0]][1] assert isinstance(first_src_dtype, DType) # mypy dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5] + wmma_helper = functools.partial(generic_wmma_helper, inp, warp_size) # TODO: refactor these to a shared TensorCoreLayout in kernel.py if device == "METAL": # A (2 elements on 32 threads): row major @@ -203,7 +204,7 @@ class PythonProgram: ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map) elif device == "CPU": def elem(x, col, row, _): return x[col+row][0] # k is always 0 - def c_map(_, elem): return (elem%16, elem//16) + def c_map(lane, elem): return (elem%16, elem//16) ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map) else: raise NotImplementedError(f"unimplemented tensor core {arg}") elif uop in GroupOp.ALU: