diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index f3b7dccf12..1f4ef3ff54 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -121,12 +121,9 @@ class PythonProgram: elif uop is Ops.WMMA: # here are the models for the WMMA instruction on the different hardware def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map): - assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread, it has {len(inp[0])}" - assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread, it has {len(inp[1])}" - assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread, it has {len(inp[2])}" - assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A in WMMA" - assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B in WMMA" - assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C in WMMA" + for cc, tinp, num in zip(("A", "B", "C"), inp, (NUM_A, NUM_B, NUM_C)): + assert len(tinp) == num, f"{cc} must have {num} elements per thread, it has {len(tinp)}" + assert len(flatten(tinp)) == num * warp_size, f"WMMA must have {num * warp_size} total elements for {cc} in WMMA" assert warp_size > 0 and warp_size % WARP_THREADS == 0, f"must have multiples of {WARP_THREADS} warp threads" out = [inp[2][elem_idx][:] for elem_idx in range(NUM_C)] for goff in range(0, warp_size, WARP_THREADS): @@ -203,5 +200,4 @@ class PythonAllocator(Allocator): def _copyout(self, dest:memoryview, src): dest[:] = src class PythonDevice(Compiled): - def __init__(self, device:str): - super().__init__(device, PythonAllocator(), PythonRenderer(), PythonCompiler(), PythonProgram) + def __init__(self, device:str): super().__init__(device, PythonAllocator(), PythonRenderer(), PythonCompiler(), PythonProgram)