mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-16 01:26:29 -05:00
no loop assert in ops_python [pr] (#7834)
This commit is contained in:
@@ -121,12 +121,9 @@ class PythonProgram:
|
||||
elif uop is Ops.WMMA:
|
||||
# here are the models for the WMMA instruction on the different hardware
|
||||
def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
|
||||
assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread, it has {len(inp[0])}"
|
||||
assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread, it has {len(inp[1])}"
|
||||
assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread, it has {len(inp[2])}"
|
||||
assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A in WMMA"
|
||||
assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B in WMMA"
|
||||
assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C in WMMA"
|
||||
for cc, tinp, num in zip(("A", "B", "C"), inp, (NUM_A, NUM_B, NUM_C)):
|
||||
assert len(tinp) == num, f"{cc} must have {num} elements per thread, it has {len(tinp)}"
|
||||
assert len(flatten(tinp)) == num * warp_size, f"WMMA must have {num * warp_size} total elements for {cc} in WMMA"
|
||||
assert warp_size > 0 and warp_size % WARP_THREADS == 0, f"must have multiples of {WARP_THREADS} warp threads"
|
||||
out = [inp[2][elem_idx][:] for elem_idx in range(NUM_C)]
|
||||
for goff in range(0, warp_size, WARP_THREADS):
|
||||
@@ -203,5 +200,4 @@ class PythonAllocator(Allocator):
|
||||
def _copyout(self, dest:memoryview, src): dest[:] = src
|
||||
|
||||
class PythonDevice(Compiled):
|
||||
def __init__(self, device:str):
|
||||
super().__init__(device, PythonAllocator(), PythonRenderer(), PythonCompiler(), PythonProgram)
|
||||
def __init__(self, device:str): super().__init__(device, PythonAllocator(), PythonRenderer(), PythonCompiler(), PythonProgram)
|
||||
|
||||
Reference in New Issue
Block a user