fold using outerworld range (#13286)
* scan using outerworld range
* almost
* sched
* simple range
* mypy
* woooo outer range
* spec passes
* print the numbers
* lol it runs
* real test
@@ -11,6 +11,52 @@ class TestOuterworldReduce(unittest.TestCase):
     t = Tensor(UOp(Ops.REDUCE, dtype=out.uop.dtype, src=(out.uop, a), arg=Ops.ADD))
     self.assertListEqual(t.tolist(), [5.,5.,5.,5.,5.])
 
+# TODO: delete test_outerworld_range?
+class TestOuterRange(unittest.TestCase):
+  def test_simple_range(self):
+    a = Tensor.ones(10).contiguous()
+    acc = Tensor.zeros().contiguous()
+    Tensor.realize(a, acc)
+
+    # this is fold
+    i = UOp.range(10, -100, AxisType.OUTER)
+    acc_i = acc.uop.after(i)
+    vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
+    out = Tensor(acc.uop.after(acc_i.store(acc_i + a[vi].uop).end(i)))
+    out.realize()
+    assert out.item() == 10.0
+
+  def test_inner_range(self):
+    a = Tensor.ones(10, 10).contiguous()
+    acc = Tensor.zeros(10).contiguous()
+    Tensor.realize(a, acc)
+
+    # this is fold
+    i = UOp.range(10, -100, AxisType.OUTER)
+    acc_i = acc.uop.after(i)
+    vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
+    out = Tensor(acc.uop.after(acc_i.store(acc_i + a[:, vi].uop).end(i)))
+    out.realize()
+    assert all(x == 10.0 for x in out.tolist())
+
+  def test_range_matmul(self):
+    vec = Tensor.randn(1, 10).realize()
+    mats = Tensor.randn(3, 10, 10).realize()
+
+    # 3 matmuls in "scan"
+    ref = ((vec @ mats[0]) @ mats[1]) @ mats[2]
+    ref.realize()
+
+    # 3 matmuls with outer world range
+    i = UOp.range(3, -100, AxisType.OUTER)
+    vec_i = Tensor(vec.uop.after(i))
+    vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
+    out = Tensor(vec.uop.after(vec_i.uop.store((vec_i.contiguous() @ mats[vi]).uop).end(i)))
+    out.realize()
+
+    # TODO: testing allclose
+    assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
+
 class TestOuterworld(unittest.TestCase):
   def test_range_plus_1(self):
     t = Tensor.arange(100).reshape(10,10).realize()
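
For intuition, the new tests drive a host-side loop: an OUTER range i re-runs one kernel per trip, accumulating into the same buffer. A plain-Python sketch of what test_simple_range computes (illustrative semantics only, not tinygrad API):

# plain-Python model of test_simple_range (illustrative semantics only)
a = [1.0] * 10        # Tensor.ones(10).contiguous()
acc = 0.0             # Tensor.zeros().contiguous()
for i in range(10):   # i = UOp.range(10, -100, AxisType.OUTER)
  acc = acc + a[i]    # acc_i.store(acc_i + a[vi].uop), closed by .end(i)
assert acc == 10.0    # out.item() == 10.0
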
@@ -1,5 +1,5 @@
 from typing import cast
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from collections import deque, defaultdict
 from tinygrad.uop.ops import UOp, Ops, buffers
 from tinygrad.device import Device, Buffer, MultiBuffer
@@ -13,6 +13,7 @@ class ScheduleItem:
   bufs: tuple[Buffer, ...]
   metadata: tuple[Metadata, ...] = ()
   fixedvars: dict[str, int] = field(default_factory=dict)
+  bound_ranges: tuple[UOp, ...] = ()
 
 # **** schedule linearizer
 
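The new bound_ranges field holds the BIND uops of any outer ranges a kernel runs under; at expansion time each one becomes a fixedvars entry via dataclasses.replace. A minimal stand-in demo of that pattern (Item is hypothetical, not ScheduleItem):

# hypothetical stand-in for ScheduleItem, showing the replace pattern
from dataclasses import dataclass, field, replace

@dataclass
class Item:
  name: str
  fixedvars: dict = field(default_factory=dict)
  bound_ranges: tuple = ()

si = Item("kernel_A", bound_ranges=("i",))
# expansion: turn each bound range into a concrete loop index, drop the binds
expanded = replace(si, fixedvars=si.fixedvars | {"i": 2}, bound_ranges=())
assert expanded.fixedvars == {"i": 2} and expanded.bound_ranges == ()
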
@@ -25,7 +26,8 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
     if u.op is not Ops.AFTER: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
     k = u.src[1]
     in_degree.setdefault(k, 0)
-    for s in k.src:
+    if k.op is Ops.RANGE: continue
+    for s in k.src[0].src if k.op is Ops.END else k.src:
       if s.op is Ops.AFTER:
         children[s.src[1]].append(k)
         in_degree[k] += 1
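In the dependency pass above, RANGE uops contribute no edges, and END uops borrow their dependencies from the node they close (src[0]). A toy sketch of that rule (hypothetical Node class standing in for UOp):

# toy model of the dependency rule above (hypothetical Node standing in for UOp)
class Node:
  def __init__(self, op, *src): self.op, self.src = op, src

r = Node("RANGE")
k1 = Node("KERNEL")
after = Node("AFTER", "buf", k1)   # AFTER(buffer, kernel): k1 wrote this buffer
k2 = Node("KERNEL", after)         # k2 reads what k1 wrote
e = Node("END", k2, r)             # END closes range r around k2

def dep_srcs(k):
  if k.op == "RANGE": return ()                     # ranges add no edges
  return k.src[0].src if k.op == "END" else k.src   # look through END to its body

assert [s.op for s in dep_srcs(e)] == ["AFTER"]     # the END depends on k1's write
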
@@ -39,16 +41,19 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
     elif s.op is Ops.BUFFER:
       pass # a BUFFER is already realized, nothing to do here
     elif s.op is Ops.BIND:
-      var, val = s.unbind()
-      assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
-      var_vals[var.expr] = val
+      # for RANGE this is in fixedvars
+      if s.src[1].op is not Ops.RANGE:
+        var, val = s.unbind()
+        assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
+        var_vals[var.expr] = val
     else:
       raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
 
 # linearize KERNEL UOps into ScheduleItems in BFS order
 
 def _heuristic(k: UOp):
-  if k.arg.ast.op is Ops.COPY and not all_same([Device[cast(Buffer, s.buf_uop.buffer).device].group_id for s in k.src]): return 1000
+  if k.op is Ops.KERNEL and k.arg.ast.op is Ops.COPY and not all_same([Device[cast(Buffer, s.buf_uop.buffer).device].group_id for s in k.src]):
+    return 1000
   return 0
 
 last_heuristic: int = 0
@@ -57,27 +62,53 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
   for k,v in in_degree.items():
     if v == 0: queues[_heuristic(k)].append(k)
 
-  schedule: list[ScheduleItem] = []
+  schedule: list[ScheduleItem|UOp] = []
   while last_queue or any(queues.values()):
     if not last_queue: last_heuristic, last_queue = min((it for it in queues.items() if it[1]), key=lambda x: abs(x[0]-last_heuristic))
-    k = last_queue.popleft()
-    ast = k.arg.ast
-    # create subbuffers if needed
-    if ast.op is Ops.BUFFER_VIEW:
-      base = k.src[1].buf_uop.buffer
-      assert isinstance(base, Buffer), "base can't be MultiBuffer"
-      buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
-    ubufs = tuple(s.buf_uop.buffer for s in k.src if s.op is not Ops.BIND)
-    if any(isinstance(x, MultiBuffer) for x in ubufs):
-      assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
-      dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
-      for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
-        schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0].expr:i} if len(dnums) else {}))
+    k = rk = last_queue.popleft()
+    if k.op is Ops.END: k = k.src[0]
+    if k.op is Ops.RANGE: schedule.append(k)
+    elif k.op is Ops.KERNEL:
+      ast = k.arg.ast
+      # create subbuffers if needed
+      if ast.op is Ops.BUFFER_VIEW:
+        base = k.src[1].buf_uop.buffer
+        assert isinstance(base, Buffer), "base can't be MultiBuffer"
+        buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
+      ubufs = tuple(s.buf_uop.buffer for s in k.src if s.op is not Ops.BIND)
+      bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and s.src[1].op is Ops.RANGE)
+      if any(isinstance(x, MultiBuffer) for x in ubufs):
+        assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
+        dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
+        for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
+          schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0].expr:i} if len(dnums) else {}, bound_ranges=bound_ranges))
+      else:
+        # ONE -> ONE
+        schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata, bound_ranges=bound_ranges))
+      if rk.op is Ops.END: schedule.append(rk)
     else:
-      # ONE -> ONE
-      schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
+      raise RuntimeError(f"can't schedule {k.op}")
     for x in children[k]:
       in_degree[x] -= 1
       if in_degree[x] == 0: queues[_heuristic(x)].append(x)
 
-  return schedule, var_vals
+  # expand the ranges in the schedule
+  real_schedule: list[ScheduleItem] = []
+  sched_ptr = 0
+  in_ranges = {}
+  range_ptrs = {}
+  while sched_ptr < len(schedule):
+    si = schedule[sched_ptr]
+    if isinstance(si, UOp):
+      if si.op is Ops.RANGE:
+        in_ranges[si] = 0
+        range_ptrs[si] = sched_ptr + 1
+      elif si.op is Ops.END:
+        if in_ranges[si.src[1]] < si.src[1].vmax:
+          in_ranges[si.src[1]] += 1
+          sched_ptr = range_ptrs[si.src[1]]
+          continue
+    else:
+      real_schedule.append(replace(si, fixedvars=si.fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in si.bound_ranges}, bound_ranges=()))
+    sched_ptr += 1
+  return real_schedule, var_vals
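The linearizer above now interleaves RANGE and END uops with ScheduleItems, then replays the span between each RANGE and its END vmax+1 times, pinning the current trip count into fixedvars. A self-contained toy of that replay loop (stand-in strings for uops; the real code pins only the ranges listed in an item's bound_ranges):

# toy replay of a linear schedule with RANGE/END markers (illustrative only)
schedule = [("RANGE", "i"), ("KERNEL", "A"), ("END", "i"), ("KERNEL", "B")]
vmax = {"i": 2}                        # the range runs vmax+1 = 3 trips
real_schedule, in_ranges, range_ptrs, ptr = [], {}, {}, 0
while ptr < len(schedule):
  op, arg = schedule[ptr]
  if op == "RANGE":
    in_ranges[arg] = 0                 # current trip of this range
    range_ptrs[arg] = ptr + 1          # loop body starts right after the RANGE
  elif op == "END":
    if in_ranges[arg] < vmax[arg]:     # more trips left: rewind to the body
      in_ranges[arg] += 1
      ptr = range_ptrs[arg]
      continue
  else:
    # the real code pins only ranges in si.bound_ranges; here every kernel sees them all
    real_schedule.append((arg, dict(in_ranges)))
  ptr += 1
assert real_schedule == [("A", {"i": 0}), ("A", {"i": 1}), ("A", {"i": 2}), ("B", {"i": 2})]
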
@@ -24,8 +24,8 @@ def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
 pm_generate_realize_map = PatternMatcher([
   # always realize SINK src
   (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
-  # always realize COPY/BUFFER_VIEW/CONTIGUOUS
-  (UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS}, name="tr"), realize),
+  # always realize COPY/BUFFER_VIEW/CONTIGUOUS/STORE
+  (UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE}, name="tr"), realize),
   # realize srcs of COPY, MSELECT, MSTACK
   (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
   # realize ASSIGN and input to assign (might be optimized out)
@@ -51,21 +51,25 @@ class IndexingContext:
     return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.index, 0)
 
 def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
-  if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None
-  if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None
+  if x.op in {Ops.BUFFERIZE, Ops.INDEX, Ops.AFTER}: return None
   new_srcs = []
   for s in x.src:
     new_src = s
-    if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL):
+    if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}:
      if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
    elif s in ctx.realize_map:
      realized_ranges = ctx.realize_map[s]
      assert isinstance(realized_ranges, list), "realize map must contain range list"
      closed_ranges = tuple([r for i,r in enumerate(ctx.range_map[s][1]) if i in realized_ranges])
-      # None in the device assigns it a number later
-      opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL)
-      new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
-      if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
+      if s.op is Ops.STORE:
+        # add the ends if this is a store
+        new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE])
+        del ctx.realize_map[s]
+      else:
+        # None in the device assigns it a number later
+        opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL)
+        new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
+        if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
     new_srcs.append(new_src)
   # NOTE: do we need this?
   return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None
@@ -471,6 +471,9 @@ pm_add_range_tags = PatternMatcher([
 def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
   if len([r for r in x.ranges if r.arg[-1] != AxisType.OUTER]): return None
 
+  # ends of outer range don't go in kernels
+  if x.op is Ops.END and x.src[1].op is Ops.RANGE and x.src[1].arg[-1] == AxisType.OUTER: return None
+
   # local kernel rewrite
   lctx = LocalAddBufferContext()
   ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen+pm_remove_tags, ctx=lctx, name="kernel split", bottom_up=True)
@@ -504,7 +507,7 @@ def tag_uop(ctx:list[UOp], x:UOp):
   return x.replace(tag=(len(ctx)-1,))
 add_tags = PatternMatcher([
   # don't tag BUFFERs, they are global
-  (UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL,
+  (UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
         Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
   (UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
 ])
@@ -40,6 +40,9 @@ shared_spec = PatternMatcher([
         rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
         all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
   (UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
+
+  # RANGE/SPECIAL define loops, END closes them
+  (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
 ])
 
 # ***** UOp spec in the Tensor graph *****
@@ -109,6 +112,10 @@ _tensor_spec = PatternMatcher([
 
   # AFTER if things were kernelized
   (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
+
+  # Tensor range bind / store
+  (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat(Ops.RANGE)), arg=None), lambda: True),
+  (UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True)
 ])+movement_ops+shared_spec
 
 tensor_spec = PatternMatcher([
@@ -128,9 +135,6 @@ shared_codegen_spec = PatternMatcher([
   (UPat(Ops.AFTER, src=(UPat(GroupOp.Defines|{Ops.AFTER}),), allow_any_len=True), lambda: True),
   (UPat(Ops.GROUP, dtypes.void), lambda: True),
 
-  # RANGE/SPECIAL define loops, END closes them
-  (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
-
   # WMMA has a <a, b, acc>
   (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
 
@@ -64,6 +64,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
     # always exclude DEVICE/CONST/UNIQUE
     if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
     if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
+    if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
   for u in toposort:
     if u in excluded: continue
     argst = codecs.decode(str(u.arg), "unicode_escape")