From 5e01cc299bd1110fe7b1463985870d290a4c4bd5 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 28 Oct 2025 22:49:55 +0800 Subject: [PATCH] zero len ranges fail (#12974) * zero len ranges fail * fix Python backend * fix llvm * fix ptx * yolo fix nir * this works... * always store... * always store... * Revert "always store..." This reverts commit 0816cf344d94466ea889b7b416dccb61a20f3739. --- .github/workflows/test.yml | 2 +- test/test_uops.py | 7 ++ tinygrad/renderer/llvmir.py | 20 ++++-- tinygrad/renderer/nir.py | 11 +-- tinygrad/renderer/ptx.py | 6 +- tinygrad/runtime/ops_python.py | 122 ++++++++++++++++----------------- 6 files changed, 93 insertions(+), 75 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f6f33ad53d..eaef0445ff 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -309,7 +309,7 @@ jobs: key: spec-unit deps: testing_unit - name: Test SPEC=2 - run: IGNORE_OOB=0 SPEC=2 PYTHONPATH="." pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/unit/test_hashing.py --timeout 40 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }} + run: IGNORE_OOB=0 SPEC=2 PYTHONPATH="." pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }} fuzzing: name: Fuzzing diff --git a/test/test_uops.py b/test/test_uops.py index c55ae0ff27..5bf49bd9a4 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -559,5 +559,12 @@ class TestUOpRender(unittest.TestCase): u = UOp(Ops.VECTORIZE, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2))) self.assertEqual(u.render(), "(0, 1, 2)") +class TestZeroRange(unittest.TestCase): + def test_reduce_variable(self): + for i in range(3,-1,-1): + v = UOp.variable("i", 0, 5).bind(i) + out = Tensor.ones(10, dtype=dtypes.int).contiguous().shrink(((0,v),)).sum() + self.assertEqual(out.item(), i) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index ce053157cb..e83e44364f 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -107,14 +107,20 @@ base_rewrite = PatternMatcher([ # range (UPat(Ops.RANGE, name="r"), lambda ctx,r: - f" br label %loop_entry_{range_str(r)}\nloop_entry_{range_str(r)}:\n" - f" br label %loop_body_{range_str(r)}\nloop_body_{range_str(r)}:\n" - f" {ctx[r]} = phi {ldt(r.dtype)} [ 0, %loop_entry_{range_str(r)} ], [ {ctx[r]}phi, %loop_latch_{range_str(r)} ]"), - (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE, name="r")), name="x"), lambda ctx,x,r: - f" br label %loop_latch_{range_str(r)}\nloop_latch_{range_str(r)}:\n" + f" br label %loop_entry_{range_str(r)}\n" + f"loop_entry_{range_str(r)}:\n" + f" br label %loop_latch_{range_str(r)}\n" + f"loop_latch_{range_str(r)}:\n" + f" {ctx[r]} = phi {ldt(r.dtype)} [ 0, %loop_entry_{range_str(r)} ], [ {ctx[r]}phi, %loop_footer_{range_str(r)} ]\n" f" {ctx[r]}phi = add {ldt(r.dtype)} {ctx[r]}, 1\n" - f" {ctx[x]} = icmp ult {ldt(r.dtype)} {ctx[r]}phi, {ctx[r.src[0]]}\n" - f" br i1 {ctx[x]}, label %loop_body_{range_str(r)}, label %loop_exit_{range_str(r)}\nloop_exit_{range_str(r)}:"), + f" {ctx[r]}cmp = icmp ult {ldt(r.dtype)} {ctx[r]}, {ctx[r.src[0]]}\n" + f" br i1 {ctx[r]}cmp, label %loop_body_{range_str(r)}, label %loop_exit_{range_str(r)}\n" + f"loop_body_{range_str(r)}:"), + (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE, name="r"))), lambda r: + f" br label %loop_footer_{range_str(r)}\n" + f"loop_footer_{range_str(r)}:\n" + f" br label %loop_latch_{range_str(r)}\n" + f"loop_exit_{range_str(r)}:"), # if (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"), diff --git a/tinygrad/renderer/nir.py b/tinygrad/renderer/nir.py index 9282e2034e..99c51531df 100644 --- a/tinygrad/renderer/nir.py +++ b/tinygrad/renderer/nir.py @@ -3,7 +3,7 @@ from tinygrad.dtype import AddrSpace, DType, PtrDType, dtypes from tinygrad.helpers import DEBUG, OSX, unwrap from tinygrad.renderer import Renderer from tinygrad.renderer.cstyle import CUDARenderer -from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat +from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str import tinygrad.runtime.autogen.mesa as mesa import base64, ctypes, ctypes.util, struct, functools, inspect @@ -182,14 +182,17 @@ class NIRRenderer(Renderer): self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long) self.b.shader.contents.info.shared_size += u.dtype.nbytes() elif u.op == Ops.RANGE: - ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{u.arg[0]}".encode()).contents)) + ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents)) nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype) mesa.nir_push_loop(self.b) self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype) + nif(self.b, nalu(self.b, "ilt", self.r[u], self.r[u.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break)) elif u.op == Ops.END: r = u.src[1] - nif(self.b, nalu(self.b, "ilt", x:=nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype)), self.r[r.src[0]]), - functools.partial(nstore, self.b, AddrSpace.REG, ranges.pop(), x, r.dtype), lambda: njump(self.b, mesa.nir_jump_break)) + next_i = nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype)) + # TODO: this nif should be removable ... but TestMultiTensor.test_double_matmul_shard_W_0 segfaults with it gone + nif(self.b, nalu(self.b, "ilt", next_i, self.r[r.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break)) + nstore(self.b, AddrSpace.REG, ranges.pop(), next_i, r.dtype), mesa.nir_pop_loop(self.b, None) else: if (d:=self.def_rewrite.rewrite(u, ctx=self)) is None: raise RuntimeError(f"failed to render {u.op} srcs {[x.dtype for x in u.src]}") diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 2cb6ef683f..b94ff21d61 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -119,8 +119,12 @@ string_rewrite = PatternMatcher([ if x.dtype.count > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"), # simple (UPat(Ops.DEFINE_REG, src=()), lambda ctx: []), - (UPat(Ops.RANGE, name="r"), lambda ctx, r: [f"mov.u32 {ctx.r[r]}, 0;", "LOOP_" + f"{ctx.r[r][1:]}:"]), + (UPat(Ops.RANGE, name="r"), lambda ctx, r: [ + f"mov.u32 {ctx.r[r]}, -1;", + f"bra END_{ctx.r[r][1:]};", + "LOOP_" + f"{ctx.r[r][1:]}:"]), (UPat(Ops.END, name="x", src=(UPat(), UPat(Ops.RANGE, name="r"))), lambda ctx, x, r: [ + "END_" + f"{ctx.r[r][1:]}:", ctx.code_for_op[Ops.ADD](ctx.r[r], ctx.r[r], "1", dtypes.int, ctx.types[dtypes.int]), ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[r], ctx.r[r.src[0]], dtypes.int, ctx.types[dtypes.int]), f"@{ctx.r[x]} bra LOOP_{ctx.r[r][1:]};"]), diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index b8ed1654d2..6a98a23fb3 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -52,41 +52,38 @@ def generic_wmma_helper(inp, warp_size, WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_ class PythonProgram: def __init__(self, name:str, lib:bytes): - self.uops: list[tuple[Ops, DType|None, list[int], Any]] = pickle.loads(lib) + self.uops: list[tuple[Ops, DType, list[int], Any]] = pickle.loads(lib) def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False): st = time.perf_counter() warp = list(itertools.product(*[range(x) for x in local_size[::-1]])) warp_size = len(warp) + void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE} + loop_ends: dict[int, int] = {srcs[1]:i for i, (uop, _, srcs, _) in enumerate(self.uops) if uop == Ops.END} for idxs in itertools.product(*[range(x) for x in global_size[::-1]]): - ul: dict[int, Any] = {} - dl: dict[int, DType] = {} + values: dict[int, Any] = {} pbufs: list[memoryview] = list(bufs) pvals: list[int] = list(vals) i = 0 - loop_ends: dict[int, int] = {} while i < len(self.uops): - uop, dtype, idp, arg = self.uops[i] - void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE} - inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops] - dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops] - if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp) + uop, dtype, srcs, arg = self.uops[i] + src_values = [values[v] for v in srcs if self.uops[v][0] not in void_ops] + src_dtypes = [self.uops[v][1] for v in srcs if self.uops[v][0] not in void_ops] + if getenv("TRACE"): print(i, uop, dtype, arg, src_values, src_dtypes) if uop is Ops.END: - loop_ends[idp[1]] = i - i = idp[1] + i = srcs[1] continue if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP): # in the python emulator, the warp is always in sync i += 1 continue assert dtype is not None, f"{uop} is missing a dtype" - dl[i] = dtype if uop is Ops.STORE: - for j,val in enumerate(inp[1] if dtp[1].count > 1 else [inp[1]]): - for (m,o,g),v in zip(inp[0], val): - if g: _store(m, o+j, v, dtp[1].scalar()) + for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]): + for (m,o,g),v in zip(src_values[0], val): + if g: _store(m, o+j, v, src_dtypes[1].scalar()) i += 1 continue - if uop is Ops.AFTER: ul[i] = inp[0] + if uop is Ops.AFTER: values[i] = src_values[0] elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}: assert isinstance(dtype, PtrDType), dtype storage_fmt = storage_fmt_for_dtype(dtype.base.scalar()) @@ -94,72 +91,73 @@ class PythonProgram: if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e" if uop is Ops.DEFINE_REG: # REGs are per thread - ul[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)] + values[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)] else: buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.DEFINE_GLOBAL else pbufs.pop(0) - ul[i] = [buf.cast(storage_fmt)] * warp_size + values[i] = [buf.cast(storage_fmt)] * warp_size elif uop is Ops.DEFINE_VAR: - ul[i] = [pvals.pop(0)] * warp_size + values[i] = [pvals.pop(0)] * warp_size elif uop is Ops.SPECIAL: - if arg[0] == 'g': ul[i] = [idxs[2-int(arg[-1])]] * warp_size - elif arg[0] == 'l': ul[i] = [x[2-int(arg[-1])] for x in warp] - elif uop is Ops.CONST: ul[i] = [arg] * warp_size + if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size + elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp] + elif uop is Ops.CONST: values[i] = [arg] * warp_size elif uop is Ops.INDEX: ret:list = [] - if isinstance(dtp[0], ImageDType): - for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]): - if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append((m, None)) - else: ret.append((m, ox*4 + oy*dtp[0].shape[1]*4)) + if isinstance(src_dtypes[0], ImageDType): + for m,ox,oy in zip(src_values[0], src_values[1][0], src_values[1][1]): + if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None)) + else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4)) else: - for m,o in zip(inp[0], inp[1]): ret.append((m,o)) - ul[i] = [(m,o,g) for (m,o),g in zip(ret, inp[2] if len(inp) == 3 else [True]*len(ret))] # set the gate last + for m,o in zip(src_values[0], src_values[1]): ret.append((m,o)) + values[i] = [(m,o,g) for (m,o),g in zip(ret, src_values[2] if len(src_values) == 3 else [True]*len(ret))] # set the gate last elif uop is Ops.CAST and isinstance(dtype, PtrDType): - ul[i] = inp[0] + values[i] = src_values[0] elif uop is Ops.RANGE: - if i not in ul: ul[i] = [0] * warp_size + if i not in values: values[i] = [0] * warp_size else: - for j in range(len(ul[i])): - ul[i][j] += 1 - if ul[i][0] == inp[0][0]: - del ul[i] - i = loop_ends[i] + 1 - continue - elif uop is Ops.VECTORIZE: ul[i] = inp + for j in range(len(values[i])): + values[i][j] += 1 + if values[i][0] == src_values[0][0]: + del values[i] + i = loop_ends[i] + 1 + continue + elif uop is Ops.VECTORIZE: values[i] = src_values elif uop is Ops.BITCAST: - packed = struct.pack(str(warp_size) + storage_fmt_for_dtype(dtp[0].scalar()), *[to_storage_scalar(x, dtp[0].scalar()) for x in inp[0]]) - ul[i] = list(struct.unpack(str(warp_size) + storage_fmt_for_dtype(dtype.scalar()), packed)) - ul[i] = [from_storage_scalar(x, dtype.scalar()) for x in ul[i]] + packed = struct.pack(str(warp_size) + storage_fmt_for_dtype(src_dtypes[0].scalar()), + *[to_storage_scalar(x, src_dtypes[0].scalar()) for x in src_values[0]]) + values[i] = list(struct.unpack(str(warp_size) + storage_fmt_for_dtype(dtype.scalar()), packed)) + values[i] = [from_storage_scalar(x, dtype.scalar()) for x in values[i]] elif uop is Ops.CAST: - ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]] + values[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in src_values[0]] elif uop is Ops.LOAD: if dtype.count > 1: - ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j, dtype.scalar()) \ - for j in range(dtype.count)] + values[i] = [load([src_values[i][j] if i != 0 and src_dtypes[i].count > 1 else src_values[i] \ + for i in range(len(src_values))], j, dtype.scalar()) for j in range(dtype.count)] else: - ul[i] = load(inp, 0, dtype) - elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)] + values[i] = load(src_values, 0, dtype) + elif uop is Ops.GEP: values[i] = src_values[0][get_single_element(arg)] elif uop is Ops.WMMA: - first_src_dtype = self.uops[idp[0]][1] + first_src_dtype = self.uops[srcs[0]][1] assert isinstance(first_src_dtype, DType) # mypy dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5] - wmma_helper = functools.partial(generic_wmma_helper, inp, warp_size) + wmma_helper = functools.partial(generic_wmma_helper, src_values, warp_size) # TODO: refactor these to a shared TensorCoreLayout in kernel.py if device == "METAL": # A (2 elements on 32 threads): row major def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16] # (i, j), C, D (2 elements on 32 threads): row major same as A/B def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4) - ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map) + values[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map) elif device == "AMD" and threads == 64: def a_elem(x, k, row, goff): return x[k%(dims[2]//4)][goff + (k//(dims[2]//4))*16 + row] def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem) - ul[i] = wmma_helper(64, dims[2], len(inp[0]), len(inp[1]), len(inp[2]), a_elem, b_elem, c_map) - elif device == "AMD" and len(inp[0]) == 8: # RDNA4 + values[i] = wmma_helper(64, dims[2], len(src_values[0]), len(src_values[1]), len(src_values[2]), a_elem, b_elem, c_map) + elif device == "AMD" and len(src_values[0]) == 8: # RDNA4 def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]] def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) def c_map(lane, elem): return (lane%16, (lane//16)*8 + elem) - ul[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map) + values[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map) elif device == "AMD": # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15 def a_elem(x, k, row, goff): @@ -168,7 +166,7 @@ class PythonProgram: # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15 def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major - ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map) + values[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map) elif device == "CUDA": # (col, row) given (lane, elem) for C & D (4 elements on 32 threads); shared by all tc shapes with M=16 N=8 def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8) @@ -176,22 +174,22 @@ class PythonProgram: if dims == (8,16,16): def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2 + (k//8)*4][goff + (k//2)%4 + (row%8)*4] def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4] - ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map) + values[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map) elif dims == (8,16,32): def a_elem(x, k, row, goff): return x[k%4 + (row//8)*4 + (k//16)*8][goff + (k//4)%4 + (row%8)*4] def b_elem(x, col, k, goff): return x[k%4 + (k//16)*4][goff + (k//4)%4 + col*4] - ul[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map) + values[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map) elif dims == (8,16,8) and dtype_in == dtypes.half: def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4] def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4] - ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map) + values[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map) elif dims == (8,16,8) and dtype_in == dtypes.float: def a_elem(x, k, row, goff): return x[(k//4)*2 + row//8][goff + k%4 + (row%8)*4] def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4] - ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map) + values[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map) else: raise NotImplementedError(f"unimplemented tensor core {arg}") elif device == "INTEL": @@ -201,17 +199,17 @@ class PythonProgram: def b_elem(x, col, k, goff): return x[k][goff+col] # C, D (8 elements on 8 threads) def c_map(lane, elem): return (lane, elem) - ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map) + values[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map) elif device == "CPU": def elem(x, col, row, _): return x[col+row][0] # k is always 0 def c_map(lane, elem): return (elem%16, elem//16) - ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map) + values[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map) else: raise NotImplementedError(f"unimplemented tensor core {arg}") elif uop in GroupOp.ALU: - assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {uop}" - assert all_same([dtype] + dtp) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}" - ul[i] = [exec_alu(uop, dtype, p) for p in zip(*inp)] - assert i in ul, (uop, dtype, idp, arg) + assert all_same([len(x) for x in src_values]), f"{[len(x) for x in src_values]} doesn't match on {uop}" + assert all_same([dtype] + src_dtypes) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}" + values[i] = [exec_alu(uop, dtype, p) for p in zip(*src_values)] + assert i in values, (uop, dtype, srcs, arg) i += 1 return time.perf_counter() - st