fold using outerworld range (#13286)

* scan using outerworld range

* almost

* sched

* simple range

* mypy

* woooo outer range

* spec passes

* print the numbers

* lol it runs

* real test
This commit is contained in:
George Hotz
2025-11-14 20:43:41 -08:00
committed by GitHub
parent 567066f51f
commit 22c08b470c
6 changed files with 125 additions and 36 deletions

View File

@@ -11,6 +11,52 @@ class TestOuterworldReduce(unittest.TestCase):
t = Tensor(UOp(Ops.REDUCE, dtype=out.uop.dtype, src=(out.uop, a), arg=Ops.ADD))
self.assertListEqual(t.tolist(), [5.,5.,5.,5.,5.])
# TODO: delete test_outerworld_range?
# Tests that hand-build a loop from a UOp.range with AxisType.OUTER, store into an
# existing buffer inside that loop, close it with .end(), and realize the result.
class TestOuterRange(unittest.TestCase):
# Accumulate the 10 elements of a ones vector into a scalar accumulator, one per iteration.
def test_simple_range(self):
a = Tensor.ones(10).contiguous()
acc = Tensor.zeros().contiguous()
# realize both inputs before building the loop graph on top of their uops
Tensor.realize(a, acc)
# this is fold
# range of 10 with AxisType.OUTER. NOTE(review): -100 is presumably the range id -- confirm
i = UOp.range(10, -100, AxisType.OUTER)
# the accumulator's uop, ordered after the range
acc_i = acc.uop.after(i)
# variable "i" spanning the range's vmin..vmax, bound to the range so it can index a Tensor
vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
# store acc_i + a[vi] back into acc_i, close the range with end(i), and make out depend on it
out = Tensor(acc.uop.after(acc_i.store(acc_i + a[vi].uop).end(i)))
out.realize()
# ten ones added onto a zero accumulator
assert out.item() == 10.0
# Same fold as above, but each iteration adds one column of a (10, 10) ones matrix
# onto a length-10 accumulator.
def test_inner_range(self):
a = Tensor.ones(10, 10).contiguous()
acc = Tensor.zeros(10).contiguous()
Tensor.realize(a, acc)
# this is fold
i = UOp.range(10, -100, AxisType.OUTER)
acc_i = acc.uop.after(i)
vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
# a[:, vi] selects the column picked by the bound range variable
out = Tensor(acc.uop.after(acc_i.store(acc_i + a[:, vi].uop).end(i)))
out.realize()
# every accumulator element received ten ones
assert all(x == 10.0 for x in out.tolist())
# Chain three matmuls by looping over the first axis of mats, and compare against
# the same chain written out explicitly.
def test_range_matmul(self):
vec = Tensor.randn(1, 10).realize()
mats = Tensor.randn(3, 10, 10).realize()
# 3 matmuls in "scan"
ref = ((vec @ mats[0]) @ mats[1]) @ mats[2]
# the reference is realized before the loop below is built and run
# NOTE(review): the loop stores into vec's uop, so this ordering presumably matters -- confirm
ref.realize()
# 3 matmuls with outer world range
i = UOp.range(3, -100, AxisType.OUTER)
vec_i = Tensor(vec.uop.after(i))
vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
# each iteration stores (vec_i @ mats[vi]) back into vec_i; end(i) closes the range
out = Tensor(vec.uop.after(vec_i.uop.store((vec_i.contiguous() @ mats[vi]).uop).end(i)))
out.realize()
# TODO: testing allclose
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
class TestOuterworld(unittest.TestCase):
def test_range_plus_1(self):
t = Tensor.arange(100).reshape(10,10).realize()

View File

@@ -1,5 +1,5 @@
from typing import cast
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from collections import deque, defaultdict
from tinygrad.uop.ops import UOp, Ops, buffers
from tinygrad.device import Device, Buffer, MultiBuffer
@@ -13,6 +13,7 @@ class ScheduleItem:
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...] = ()
fixedvars: dict[str, int] = field(default_factory=dict)
bound_ranges: tuple[UOp, ...] = ()
# **** schedule linearizer
@@ -25,7 +26,8 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
if u.op is not Ops.AFTER: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
k = u.src[1]
in_degree.setdefault(k, 0)
for s in k.src:
if k.op is Ops.RANGE: continue
for s in k.src[0].src if k.op is Ops.END else k.src:
if s.op is Ops.AFTER:
children[s.src[1]].append(k)
in_degree[k] += 1
@@ -39,16 +41,19 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
elif s.op is Ops.BUFFER:
pass # a BUFFER is already realized, nothing to do here
elif s.op is Ops.BIND:
var, val = s.unbind()
assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
var_vals[var.expr] = val
# for RANGE this is in fixedvars
if s.src[1].op is not Ops.RANGE:
var, val = s.unbind()
assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
var_vals[var.expr] = val
else:
raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
# linearize KERNEL UOps into ScheduleItems in BFS order
def _heuristic(k: UOp):
if k.arg.ast.op is Ops.COPY and not all_same([Device[cast(Buffer, s.buf_uop.buffer).device].group_id for s in k.src]): return 1000
if k.op is Ops.KERNEL and k.arg.ast.op is Ops.COPY and not all_same([Device[cast(Buffer, s.buf_uop.buffer).device].group_id for s in k.src]):
return 1000
return 0
last_heuristic: int = 0
@@ -57,27 +62,53 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
for k,v in in_degree.items():
if v == 0: queues[_heuristic(k)].append(k)
schedule: list[ScheduleItem] = []
schedule: list[ScheduleItem|UOp] = []
while last_queue or any(queues.values()):
if not last_queue: last_heuristic, last_queue = min((it for it in queues.items() if it[1]), key=lambda x: abs(x[0]-last_heuristic))
k = last_queue.popleft()
ast = k.arg.ast
# create subbuffers if needed
if ast.op is Ops.BUFFER_VIEW:
base = k.src[1].buf_uop.buffer
assert isinstance(base, Buffer), "base can't be MultiBuffer"
buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
ubufs = tuple(s.buf_uop.buffer for s in k.src if s.op is not Ops.BIND)
if any(isinstance(x, MultiBuffer) for x in ubufs):
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0].expr:i} if len(dnums) else {}))
k = rk = last_queue.popleft()
if k.op is Ops.END: k = k.src[0]
if k.op is Ops.RANGE: schedule.append(k)
elif k.op is Ops.KERNEL:
ast = k.arg.ast
# create subbuffers if needed
if ast.op is Ops.BUFFER_VIEW:
base = k.src[1].buf_uop.buffer
assert isinstance(base, Buffer), "base can't be MultiBuffer"
buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
ubufs = tuple(s.buf_uop.buffer for s in k.src if s.op is not Ops.BIND)
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and s.src[1].op is Ops.RANGE)
if any(isinstance(x, MultiBuffer) for x in ubufs):
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0].expr:i} if len(dnums) else {}, bound_ranges=bound_ranges))
else:
# ONE -> ONE
schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata, bound_ranges=bound_ranges))
if rk.op is Ops.END: schedule.append(rk)
else:
# ONE -> ONE
schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
raise RuntimeError(f"can't schedule {k.op}")
for x in children[k]:
in_degree[x] -= 1
if in_degree[x] == 0: queues[_heuristic(x)].append(x)
return schedule, var_vals
# expand the ranges in the schedule
real_schedule: list[ScheduleItem] = []
sched_ptr = 0
in_ranges = {}
range_ptrs = {}
while sched_ptr < len(schedule):
si = schedule[sched_ptr]
if isinstance(si, UOp):
if si.op is Ops.RANGE:
in_ranges[si] = 0
range_ptrs[si] = sched_ptr + 1
elif si.op is Ops.END:
if in_ranges[si.src[1]] < si.src[1].vmax:
in_ranges[si.src[1]] += 1
sched_ptr = range_ptrs[si.src[1]]
continue
else:
real_schedule.append(replace(si, fixedvars=si.fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in si.bound_ranges}, bound_ranges=()))
sched_ptr += 1
return real_schedule, var_vals

View File

@@ -24,8 +24,8 @@ def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
pm_generate_realize_map = PatternMatcher([
# always realize SINK src
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
# always realize COPY/BUFFER_VIEW/CONTIGUOUS
(UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS}, name="tr"), realize),
# always realize COPY/BUFFER_VIEW/CONTIGUOUS/STORE
(UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE}, name="tr"), realize),
# realize srcs of COPY, MSELECT, MSTACK
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
# realize ASSIGN and input to assign (might be optimized out)
@@ -51,21 +51,25 @@ class IndexingContext:
return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.index, 0)
def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None
if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None
if x.op in {Ops.BUFFERIZE, Ops.INDEX, Ops.AFTER}: return None
new_srcs = []
for s in x.src:
new_src = s
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL):
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}:
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
elif s in ctx.realize_map:
realized_ranges = ctx.realize_map[s]
assert isinstance(realized_ranges, list), "realize map must contain range list"
closed_ranges = tuple([r for i,r in enumerate(ctx.range_map[s][1]) if i in realized_ranges])
# None in the device assigns it a number later
opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL)
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
if s.op is Ops.STORE:
# add the ends if this is a store
new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE])
del ctx.realize_map[s]
else:
# None in the device assigns it a number later
opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL)
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
new_srcs.append(new_src)
# NOTE: do we need this?
return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None

View File

@@ -471,6 +471,9 @@ pm_add_range_tags = PatternMatcher([
def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
if len([r for r in x.ranges if r.arg[-1] != AxisType.OUTER]): return None
# ends of outer range don't go in kernels
if x.op is Ops.END and x.src[1].op is Ops.RANGE and x.src[1].arg[-1] == AxisType.OUTER: return None
# local kernel rewrite
lctx = LocalAddBufferContext()
ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen+pm_remove_tags, ctx=lctx, name="kernel split", bottom_up=True)
@@ -504,7 +507,7 @@ def tag_uop(ctx:list[UOp], x:UOp):
return x.replace(tag=(len(ctx)-1,))
add_tags = PatternMatcher([
# don't tag BUFFERs, they are global
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL,
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
])

View File

@@ -40,6 +40,9 @@ shared_spec = PatternMatcher([
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
# RANGE/SPECIAL define loops, END closes them
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
])
# ***** UOp spec in the Tensor graph *****
@@ -109,6 +112,10 @@ _tensor_spec = PatternMatcher([
# AFTER if things were kernelized
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
# Tensor range bind / store
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat(Ops.RANGE)), arg=None), lambda: True),
(UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True)
])+movement_ops+shared_spec
tensor_spec = PatternMatcher([
@@ -128,9 +135,6 @@ shared_codegen_spec = PatternMatcher([
(UPat(Ops.AFTER, src=(UPat(GroupOp.Defines|{Ops.AFTER}),), allow_any_len=True), lambda: True),
(UPat(Ops.GROUP, dtypes.void), lambda: True),
# RANGE/SPECIAL define loops, END closes them
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
# WMMA has a <a, b, acc>
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),

View File

@@ -64,6 +64,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
# always exclude DEVICE/CONST/UNIQUE
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
for u in toposort:
if u in excluded: continue
argst = codecs.decode(str(u.arg), "unicode_escape")