zero len ranges fail (#12974)

* zero len ranges fail

* fix Python backend

* fix llvm

* fix ptx

* yolo fix nir

* this works...

* always store...

* always store...

* Revert "always store..."

This reverts commit 0816cf344d.
George Hotz (committed by GitHub)
2025-10-28 22:49:55 +08:00
parent e936aa7974
commit 5e01cc299b
6 changed files with 93 additions and 75 deletions
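
The bug: the backends lowered Ops.RANGE to a bottom-tested (do-while style) loop, so a range of length zero still executed its body once. The fix moves the trip-count test ahead of the first iteration in each backend. A minimal Python sketch of the difference (illustrative only, not tinygrad code):

def bottom_tested(n):
  i = trips = 0
  while True:              # old shape: the body runs before the test
    trips += 1             # loop body
    i += 1
    if not (i < n): break
  return trips

def top_tested(n):
  i = trips = 0
  while i < n:             # new shape: the test runs before the body
    trips += 1             # loop body
    i += 1
  return trips

assert bottom_tested(0) == 1   # the bug: a zero-length range ran once
assert top_tested(0) == 0      # the fix: it runs zero times
assert bottom_tested(3) == top_tested(3) == 3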


@@ -309,7 +309,7 @@ jobs:
       key: spec-unit
       deps: testing_unit
   - name: Test SPEC=2
-    run: IGNORE_OOB=0 SPEC=2 PYTHONPATH="." pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/unit/test_hashing.py --timeout 40 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}
+    run: IGNORE_OOB=0 SPEC=2 PYTHONPATH="." pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}
   fuzzing:
     name: Fuzzing


@@ -559,5 +559,12 @@ class TestUOpRender(unittest.TestCase):
     u = UOp(Ops.VECTORIZE, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
     self.assertEqual(u.render(), "(0, 1, 2)")
 
+class TestZeroRange(unittest.TestCase):
+  def test_reduce_variable(self):
+    for i in range(3,-1,-1):
+      v = UOp.variable("i", 0, 5).bind(i)
+      out = Tensor.ones(10, dtype=dtypes.int).contiguous().shrink(((0,v),)).sum()
+      self.assertEqual(out.item(), i)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
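
The test binds a variable extent to each length from 3 down to 0 and checks the reduce against it; the plain-Python reference semantics the kernel must match, including the empty case, are simply:

def ref_sum(n): return sum(1 for _ in range(n))   # a RANGE of extent n whose body adds 1
assert [ref_sum(n) for n in (3, 2, 1, 0)] == [3, 2, 1, 0]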


@@ -107,14 +107,20 @@ base_rewrite = PatternMatcher([
   # range
   (UPat(Ops.RANGE, name="r"), lambda ctx,r:
-    f"  br label %loop_entry_{range_str(r)}\nloop_entry_{range_str(r)}:\n"
-    f"  br label %loop_body_{range_str(r)}\nloop_body_{range_str(r)}:\n"
-    f"  {ctx[r]} = phi {ldt(r.dtype)} [ 0, %loop_entry_{range_str(r)} ], [ {ctx[r]}phi, %loop_latch_{range_str(r)} ]"),
-  (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE, name="r")), name="x"), lambda ctx,x,r:
-    f"  br label %loop_latch_{range_str(r)}\nloop_latch_{range_str(r)}:\n"
+    f"  br label %loop_entry_{range_str(r)}\n"
+    f"loop_entry_{range_str(r)}:\n"
+    f"  br label %loop_latch_{range_str(r)}\n"
+    f"loop_latch_{range_str(r)}:\n"
+    f"  {ctx[r]} = phi {ldt(r.dtype)} [ 0, %loop_entry_{range_str(r)} ], [ {ctx[r]}phi, %loop_footer_{range_str(r)} ]\n"
     f"  {ctx[r]}phi = add {ldt(r.dtype)} {ctx[r]}, 1\n"
-    f"  {ctx[x]} = icmp ult {ldt(r.dtype)} {ctx[r]}phi, {ctx[r.src[0]]}\n"
-    f"  br i1 {ctx[x]}, label %loop_body_{range_str(r)}, label %loop_exit_{range_str(r)}\nloop_exit_{range_str(r)}:"),
+    f"  {ctx[r]}cmp = icmp ult {ldt(r.dtype)} {ctx[r]}, {ctx[r.src[0]]}\n"
+    f"  br i1 {ctx[r]}cmp, label %loop_body_{range_str(r)}, label %loop_exit_{range_str(r)}\n"
+    f"loop_body_{range_str(r)}:"),
+  (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE, name="r"))), lambda r:
+    f"  br label %loop_footer_{range_str(r)}\n"
+    f"loop_footer_{range_str(r)}:\n"
+    f"  br label %loop_latch_{range_str(r)}\n"
+    f"loop_exit_{range_str(r)}:"),
   # if
   (UPat(Ops.IF, name="x"), lambda ctx,x: f"  br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
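
The LLVM lowering now splits the loop into entry, latch, body, and exit blocks, with the phi, increment, and compare all in the latch: the counter is tested before the body ever runs, and END just routes through a footer block back to the latch. A rough Python model of the new control flow (names mirror the block labels above; a sketch, not the emitted IR):

def llvm_loop(n):
  trips = 0
  i = 0                      # phi: [ 0, %loop_entry ]
  while True:                # loop_latch
    next_i = i + 1           # %phi = add i, 1
    if not (i < n): break    # %cmp = icmp ult i, n; br -> body or exit
    trips += 1               # loop_body
    i = next_i               # loop_footer: br back to loop_latch
  return trips

assert llvm_loop(0) == 0 and llvm_loop(3) == 3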


@@ -3,7 +3,7 @@ from tinygrad.dtype import AddrSpace, DType, PtrDType, dtypes
 from tinygrad.helpers import DEBUG, OSX, unwrap
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import CUDARenderer
-from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
+from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str
 import tinygrad.runtime.autogen.mesa as mesa
 import base64, ctypes, ctypes.util, struct, functools, inspect
@@ -182,14 +182,17 @@ class NIRRenderer(Renderer):
       self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long)
       self.b.shader.contents.info.shared_size += u.dtype.nbytes()
     elif u.op == Ops.RANGE:
-      ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{u.arg[0]}".encode()).contents))
+      ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
       nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
       mesa.nir_push_loop(self.b)
       self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype)
+      nif(self.b, nalu(self.b, "ilt", self.r[u], self.r[u.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
     elif u.op == Ops.END:
       r = u.src[1]
-      nif(self.b, nalu(self.b, "ilt", x:=nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype)), self.r[r.src[0]]),
-          functools.partial(nstore, self.b, AddrSpace.REG, ranges.pop(), x, r.dtype), lambda: njump(self.b, mesa.nir_jump_break))
+      next_i = nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype))
+      # TODO: this nif should be removable ... but TestMultiTensor.test_double_matmul_shard_W_0 segfaults with it gone
+      nif(self.b, nalu(self.b, "ilt", next_i, self.r[r.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
+      nstore(self.b, AddrSpace.REG, ranges.pop(), next_i, r.dtype)
       mesa.nir_pop_loop(self.b, None)
     else:
       if (d:=self.def_rewrite.rewrite(u, ctx=self)) is None: raise RuntimeError(f"failed to render {u.op} srcs {[x.dtype for x in u.src]}")
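
In NIR the same idea is expressed as an explicit break at the top of the loop body: RANGE now emits a guard that breaks before the body when the counter has reached the extent, and END increments, re-checks (the TODO notes this second check should be redundant), and stores the counter back. An approximate Python model of the emitted loop, with the nif/nstore helpers above shown as comments:

def nir_loop(n):
  trips = 0
  i = 0                          # nstore(i, 0)
  while True:                    # nir_push_loop
    if not (i < n): break        # new top-of-body nif -> nir_jump_break
    trips += 1                   # loop body
    next_i = i + 1               # END: iadd
    if not (next_i < n): break   # TODO nif, kept to avoid the segfault
    i = next_i                   # nstore back to the loop register
  return trips

assert nir_loop(0) == 0 and nir_loop(3) == 3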


@@ -119,8 +119,12 @@ string_rewrite = PatternMatcher([
     if x.dtype.count > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
   # simple
   (UPat(Ops.DEFINE_REG, src=()), lambda ctx: []),
-  (UPat(Ops.RANGE, name="r"), lambda ctx, r: [f"mov.u32 {ctx.r[r]}, 0;", "LOOP_" + f"{ctx.r[r][1:]}:"]),
+  (UPat(Ops.RANGE, name="r"), lambda ctx, r: [
+    f"mov.u32 {ctx.r[r]}, -1;",
+    f"bra END_{ctx.r[r][1:]};",
+    "LOOP_" + f"{ctx.r[r][1:]}:"]),
   (UPat(Ops.END, name="x", src=(UPat(), UPat(Ops.RANGE, name="r"))), lambda ctx, x, r: [
+    "END_" + f"{ctx.r[r][1:]}:",
     ctx.code_for_op[Ops.ADD](ctx.r[r], ctx.r[r], "1", dtypes.int, ctx.types[dtypes.int]),
     ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[r], ctx.r[r.src[0]], dtypes.int, ctx.types[dtypes.int]),
     f"@{ctx.r[x]} bra LOOP_{ctx.r[r][1:]};"]),


@@ -52,41 +52,38 @@ def generic_wmma_helper(inp, warp_size, WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_
 class PythonProgram:
   def __init__(self, name:str, lib:bytes):
-    self.uops: list[tuple[Ops, DType|None, list[int], Any]] = pickle.loads(lib)
+    self.uops: list[tuple[Ops, DType, list[int], Any]] = pickle.loads(lib)
   def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
     st = time.perf_counter()
     warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
     warp_size = len(warp)
+    void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE}
+    loop_ends: dict[int, int] = {srcs[1]:i for i, (uop, _, srcs, _) in enumerate(self.uops) if uop == Ops.END}
     for idxs in itertools.product(*[range(x) for x in global_size[::-1]]):
-      ul: dict[int, Any] = {}
-      dl: dict[int, DType] = {}
+      values: dict[int, Any] = {}
       pbufs: list[memoryview] = list(bufs)
       pvals: list[int] = list(vals)
       i = 0
-      loop_ends: dict[int, int] = {}
       while i < len(self.uops):
-        uop, dtype, idp, arg = self.uops[i]
-        void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE}
-        inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
-        dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
-        if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
+        uop, dtype, srcs, arg = self.uops[i]
+        src_values = [values[v] for v in srcs if self.uops[v][0] not in void_ops]
+        src_dtypes = [self.uops[v][1] for v in srcs if self.uops[v][0] not in void_ops]
+        if getenv("TRACE"): print(i, uop, dtype, arg, src_values, src_dtypes)
         if uop is Ops.END:
-          loop_ends[idp[1]] = i
-          i = idp[1]
+          i = srcs[1]
          continue
        if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP):
          # in the python emulator, the warp is always in sync
          i += 1
          continue
-        assert dtype is not None, f"{uop} is missing a dtype"
-        dl[i] = dtype
        if uop is Ops.STORE:
-          for j,val in enumerate(inp[1] if dtp[1].count > 1 else [inp[1]]):
-            for (m,o,g),v in zip(inp[0], val):
-              if g: _store(m, o+j, v, dtp[1].scalar())
+          for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]):
+            for (m,o,g),v in zip(src_values[0], val):
+              if g: _store(m, o+j, v, src_dtypes[1].scalar())
          i += 1
          continue
-        if uop is Ops.AFTER: ul[i] = inp[0]
+        if uop is Ops.AFTER: values[i] = src_values[0]
        elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
          assert isinstance(dtype, PtrDType), dtype
          storage_fmt = storage_fmt_for_dtype(dtype.base.scalar())
@@ -94,72 +91,73 @@ class PythonProgram:
          if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e"
          if uop is Ops.DEFINE_REG:
            # REGs are per thread
-            ul[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
+            values[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
          else:
            buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.DEFINE_GLOBAL else pbufs.pop(0)
-            ul[i] = [buf.cast(storage_fmt)] * warp_size
+            values[i] = [buf.cast(storage_fmt)] * warp_size
        elif uop is Ops.DEFINE_VAR:
-          ul[i] = [pvals.pop(0)] * warp_size
+          values[i] = [pvals.pop(0)] * warp_size
        elif uop is Ops.SPECIAL:
-          if arg[0] == 'g': ul[i] = [idxs[2-int(arg[-1])]] * warp_size
-          elif arg[0] == 'l': ul[i] = [x[2-int(arg[-1])] for x in warp]
-        elif uop is Ops.CONST: ul[i] = [arg] * warp_size
+          if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size
+          elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
+        elif uop is Ops.CONST: values[i] = [arg] * warp_size
        elif uop is Ops.INDEX:
          ret:list = []
-          if isinstance(dtp[0], ImageDType):
-            for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
-              if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append((m, None))
-              else: ret.append((m, ox*4 + oy*dtp[0].shape[1]*4))
+          if isinstance(src_dtypes[0], ImageDType):
+            for m,ox,oy in zip(src_values[0], src_values[1][0], src_values[1][1]):
+              if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
+              else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
          else:
-            for m,o in zip(inp[0], inp[1]): ret.append((m,o))
-          ul[i] = [(m,o,g) for (m,o),g in zip(ret, inp[2] if len(inp) == 3 else [True]*len(ret))] # set the gate last
+            for m,o in zip(src_values[0], src_values[1]): ret.append((m,o))
+          values[i] = [(m,o,g) for (m,o),g in zip(ret, src_values[2] if len(src_values) == 3 else [True]*len(ret))] # set the gate last
        elif uop is Ops.CAST and isinstance(dtype, PtrDType):
-          ul[i] = inp[0]
+          values[i] = src_values[0]
        elif uop is Ops.RANGE:
-          if i not in ul: ul[i] = [0] * warp_size
+          if i not in values: values[i] = [0] * warp_size
          else:
-            for j in range(len(ul[i])):
-              ul[i][j] += 1
-            if ul[i][0] == inp[0][0]:
-              del ul[i]
-              i = loop_ends[i] + 1
-              continue
-        elif uop is Ops.VECTORIZE: ul[i] = inp
+            for j in range(len(values[i])):
+              values[i][j] += 1
+          if values[i][0] == src_values[0][0]:
+            del values[i]
+            i = loop_ends[i] + 1
+            continue
+        elif uop is Ops.VECTORIZE: values[i] = src_values
        elif uop is Ops.BITCAST:
-          packed = struct.pack(str(warp_size) + storage_fmt_for_dtype(dtp[0].scalar()), *[to_storage_scalar(x, dtp[0].scalar()) for x in inp[0]])
-          ul[i] = list(struct.unpack(str(warp_size) + storage_fmt_for_dtype(dtype.scalar()), packed))
-          ul[i] = [from_storage_scalar(x, dtype.scalar()) for x in ul[i]]
+          packed = struct.pack(str(warp_size) + storage_fmt_for_dtype(src_dtypes[0].scalar()),
+                               *[to_storage_scalar(x, src_dtypes[0].scalar()) for x in src_values[0]])
+          values[i] = list(struct.unpack(str(warp_size) + storage_fmt_for_dtype(dtype.scalar()), packed))
+          values[i] = [from_storage_scalar(x, dtype.scalar()) for x in values[i]]
        elif uop is Ops.CAST:
-          ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]]
+          values[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in src_values[0]]
        elif uop is Ops.LOAD:
          if dtype.count > 1:
-            ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j, dtype.scalar()) \
-              for j in range(dtype.count)]
+            values[i] = [load([src_values[i][j] if i != 0 and src_dtypes[i].count > 1 else src_values[i] \
+              for i in range(len(src_values))], j, dtype.scalar()) for j in range(dtype.count)]
          else:
-            ul[i] = load(inp, 0, dtype)
-        elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)]
+            values[i] = load(src_values, 0, dtype)
+        elif uop is Ops.GEP: values[i] = src_values[0][get_single_element(arg)]
        elif uop is Ops.WMMA:
-          first_src_dtype = self.uops[idp[0]][1]
+          first_src_dtype = self.uops[srcs[0]][1]
          assert isinstance(first_src_dtype, DType) # mypy
          dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5]
-          wmma_helper = functools.partial(generic_wmma_helper, inp, warp_size)
+          wmma_helper = functools.partial(generic_wmma_helper, src_values, warp_size)
          # TODO: refactor these to a shared TensorCoreLayout in kernel.py
          if device == "METAL":
            # A (2 elements on 32 threads): row major
            def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
            # (i, j), C, D (2 elements on 32 threads): row major same as A/B
            def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
-            ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
+            values[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
          elif device == "AMD" and threads == 64:
            def a_elem(x, k, row, goff): return x[k%(dims[2]//4)][goff + (k//(dims[2]//4))*16 + row]
            def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
            def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem)
-            ul[i] = wmma_helper(64, dims[2], len(inp[0]), len(inp[1]), len(inp[2]), a_elem, b_elem, c_map)
-          elif device == "AMD" and len(inp[0]) == 8: # RDNA4
+            values[i] = wmma_helper(64, dims[2], len(src_values[0]), len(src_values[1]), len(src_values[2]), a_elem, b_elem, c_map)
+          elif device == "AMD" and len(src_values[0]) == 8: # RDNA4
            def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]]
            def b_elem(x, col, k, goff): return a_elem(x, k, col, goff)
            def c_map(lane, elem): return (lane%16, (lane//16)*8 + elem)
-            ul[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
+            values[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
          elif device == "AMD":
            # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
            def a_elem(x, k, row, goff):
@@ -168,7 +166,7 @@ class PythonProgram:
            # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
            def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
            def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
-            ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
+            values[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
          elif device == "CUDA":
            # (col, row) given (lane, elem) for C & D (4 elements on 32 threads); shared by all tc shapes with M=16 N=8
            def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8)
@@ -176,22 +174,22 @@ class PythonProgram:
            if dims == (8,16,16):
              def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2 + (k//8)*4][goff + (k//2)%4 + (row%8)*4]
              def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4]
-              ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
+              values[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
            elif dims == (8,16,32):
              def a_elem(x, k, row, goff): return x[k%4 + (row//8)*4 + (k//16)*8][goff + (k//4)%4 + (row%8)*4]
              def b_elem(x, col, k, goff): return x[k%4 + (k//16)*4][goff + (k//4)%4 + col*4]
-              ul[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map)
+              values[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map)
            elif dims == (8,16,8) and dtype_in == dtypes.half:
              def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4]
              def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4]
-              ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
+              values[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
            elif dims == (8,16,8) and dtype_in == dtypes.float:
              def a_elem(x, k, row, goff): return x[(k//4)*2 + row//8][goff + k%4 + (row%8)*4]
              def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4]
-              ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
+              values[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
            else: raise NotImplementedError(f"unimplemented tensor core {arg}")
          elif device == "INTEL":
@@ -201,17 +199,17 @@ class PythonProgram:
            def b_elem(x, col, k, goff): return x[k][goff+col]
            # C, D (8 elements on 8 threads)
            def c_map(lane, elem): return (lane, elem)
-            ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
+            values[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
          elif device == "CPU":
            def elem(x, col, row, _): return x[col+row][0] # k is always 0
            def c_map(lane, elem): return (elem%16, elem//16)
-            ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
+            values[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
          else: raise NotImplementedError(f"unimplemented tensor core {arg}")
        elif uop in GroupOp.ALU:
-          assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {uop}"
-          assert all_same([dtype] + dtp) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}"
-          ul[i] = [exec_alu(uop, dtype, p) for p in zip(*inp)]
-        assert i in ul, (uop, dtype, idp, arg)
+          assert all_same([len(x) for x in src_values]), f"{[len(x) for x in src_values]} doesn't match on {uop}"
+          assert all_same([dtype] + src_dtypes) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}"
+          values[i] = [exec_alu(uop, dtype, p) for p in zip(*src_values)]
+        assert i in values, (uop, dtype, srcs, arg)
        i += 1
    return time.perf_counter() - st
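
In the Python emulator the fix is the dedented extent check in the RANGE branch: it now also runs right after the counter is initialized to 0, so a zero-length RANGE jumps straight past its END. That jump target must exist before the loop body has ever executed, which is why loop_ends is now precomputed from the whole uop list instead of being filled in lazily when an END is first reached. A toy illustration of the precompute on a hand-built uop list (the op names are hypothetical; the (op, dtype, srcs, arg) tuple layout and the comprehension are the ones from the diff):

uops = [("CONST", None, [], 3),       # the range extent
        ("RANGE", None, [0], None),   # counter over the CONST extent
        ("ADD", None, [1, 1], None),  # loop body
        ("END", None, [2, 1], None)]  # src[1] points back at the RANGE
loop_ends = {srcs[1]: i for i, (uop, _, srcs, _) in enumerate(uops) if uop == "END"}
assert loop_ends == {1: 3}  # a zero-trip RANGE at index 1 jumps to loop_ends[1] + 1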