no loop assert in ops_python [pr] (#7834)

This commit is contained in:
George Hotz
2024-11-22 11:17:36 +08:00
committed by GitHub
parent d18b948f48
commit e39af63156

View File

@@ -121,12 +121,9 @@ class PythonProgram:
elif uop is Ops.WMMA:
# here are the models for the WMMA instruction on the different hardware
def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread, it has {len(inp[0])}"
assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread, it has {len(inp[1])}"
assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread, it has {len(inp[2])}"
assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A in WMMA"
assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B in WMMA"
assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C in WMMA"
for cc, tinp, num in zip(("A", "B", "C"), inp, (NUM_A, NUM_B, NUM_C)):
assert len(tinp) == num, f"{cc} must have {num} elements per thread, it has {len(tinp)}"
assert len(flatten(tinp)) == num * warp_size, f"WMMA must have {num * warp_size} total elements for {cc} in WMMA"
assert warp_size > 0 and warp_size % WARP_THREADS == 0, f"must have multiples of {WARP_THREADS} warp threads"
out = [inp[2][elem_idx][:] for elem_idx in range(NUM_C)]
for goff in range(0, warp_size, WARP_THREADS):
@@ -203,5 +200,4 @@ class PythonAllocator(Allocator):
# Copy `src` into the caller-supplied buffer: the full-slice assignment writes through `dest` instead of rebinding the local name.
def _copyout(self, dest:memoryview, src): dest[:] = src
class PythonDevice(Compiled):
def __init__(self, device:str):
super().__init__(device, PythonAllocator(), PythonRenderer(), PythonCompiler(), PythonProgram)
def __init__(self, device:str): super().__init__(device, PythonAllocator(), PythonRenderer(), PythonCompiler(), PythonProgram)