diff --git a/test/test_outerworld.py b/test/test_outerworld.py index 5714cc0748..d63a367ebb 100644 --- a/test/test_outerworld.py +++ b/test/test_outerworld.py @@ -11,6 +11,52 @@ class TestOuterworldReduce(unittest.TestCase): t = Tensor(UOp(Ops.REDUCE, dtype=out.uop.dtype, src=(out.uop, a), arg=Ops.ADD)) self.assertListEqual(t.tolist(), [5.,5.,5.,5.,5.]) +# TODO: delete test_outerworld_range? +class TestOuterRange(unittest.TestCase): + def test_simple_range(self): + a = Tensor.ones(10).contiguous() + acc = Tensor.zeros().contiguous() + Tensor.realize(a, acc) + + # this is fold + i = UOp.range(10, -100, AxisType.OUTER) + acc_i = acc.uop.after(i) + vi = UOp.variable("i", i.vmin, i.vmax).bind(i) + out = Tensor(acc.uop.after(acc_i.store(acc_i + a[vi].uop).end(i))) + out.realize() + assert out.item() == 10.0 + + def test_inner_range(self): + a = Tensor.ones(10, 10).contiguous() + acc = Tensor.zeros(10).contiguous() + Tensor.realize(a, acc) + + # this is fold + i = UOp.range(10, -100, AxisType.OUTER) + acc_i = acc.uop.after(i) + vi = UOp.variable("i", i.vmin, i.vmax).bind(i) + out = Tensor(acc.uop.after(acc_i.store(acc_i + a[:, vi].uop).end(i))) + out.realize() + assert all(x == 10.0 for x in out.tolist()) + + def test_range_matmul(self): + vec = Tensor.randn(1, 10).realize() + mats = Tensor.randn(3, 10, 10).realize() + + # 3 matmuls in "scan" + ref = ((vec @ mats[0]) @ mats[1]) @ mats[2] + ref.realize() + + # 3 matmuls with outer world range + i = UOp.range(3, -100, AxisType.OUTER) + vec_i = Tensor(vec.uop.after(i)) + vi = UOp.variable("i", i.vmin, i.vmax).bind(i) + out = Tensor(vec.uop.after(vec_i.uop.store((vec_i.contiguous() @ mats[vi]).uop).end(i))) + out.realize() + + # TODO: testing allclose + assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}" + class TestOuterworld(unittest.TestCase): def test_range_plus_1(self): t = Tensor.arange(100).reshape(10,10).realize() diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 655cd0d242..dc89589479 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,5 +1,5 @@ from typing import cast -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from collections import deque, defaultdict from tinygrad.uop.ops import UOp, Ops, buffers from tinygrad.device import Device, Buffer, MultiBuffer @@ -13,6 +13,7 @@ class ScheduleItem: bufs: tuple[Buffer, ...] metadata: tuple[Metadata, ...] = () fixedvars: dict[str, int] = field(default_factory=dict) + bound_ranges: tuple[UOp, ...] = () # **** schedule linearizer @@ -25,7 +26,8 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[ if u.op is not Ops.AFTER: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip k = u.src[1] in_degree.setdefault(k, 0) - for s in k.src: + if k.op is Ops.RANGE: continue + for s in k.src[0].src if k.op is Ops.END else k.src: if s.op is Ops.AFTER: children[s.src[1]].append(k) in_degree[k] += 1 @@ -39,16 +41,19 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[ elif s.op is Ops.BUFFER: pass # a BUFFER is already realized, nothing to do here elif s.op is Ops.BIND: - var, val = s.unbind() - assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}" - var_vals[var.expr] = val + # for RANGE this is in fixedvars + if s.src[1].op is not Ops.RANGE: + var, val = s.unbind() + assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}" + var_vals[var.expr] = val else: raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}") # linearize KERNEL UOps into ScheduleItems in BFS order def _heuristic(k: UOp): - if k.arg.ast.op is Ops.COPY and not all_same([Device[cast(Buffer, s.buf_uop.buffer).device].group_id for s in k.src]): return 1000 + if k.op is Ops.KERNEL and k.arg.ast.op is Ops.COPY and not all_same([Device[cast(Buffer, s.buf_uop.buffer).device].group_id for s in k.src]): + return 1000 return 0 last_heuristic: int = 0 @@ -57,27 +62,53 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[ for k,v in in_degree.items(): if v == 0: queues[_heuristic(k)].append(k) - schedule: list[ScheduleItem] = [] + schedule: list[ScheduleItem|UOp] = [] while last_queue or any(queues.values()): if not last_queue: last_heuristic, last_queue = min((it for it in queues.items() if it[1]), key=lambda x: abs(x[0]-last_heuristic)) - k = last_queue.popleft() - ast = k.arg.ast - # create subbuffers if needed - if ast.op is Ops.BUFFER_VIEW: - base = k.src[1].buf_uop.buffer - assert isinstance(base, Buffer), "base can't be MultiBuffer" - buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize) - ubufs = tuple(s.buf_uop.buffer for s in k.src if s.op is not Ops.BIND) - if any(isinstance(x, MultiBuffer) for x in ubufs): - assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer" - dnums = [x for x in ast.variables() if x.arg[0] == '_device_num'] - for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])): - schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0].expr:i} if len(dnums) else {})) + k = rk = last_queue.popleft() + if k.op is Ops.END: k = k.src[0] + if k.op is Ops.RANGE: schedule.append(k) + elif k.op is Ops.KERNEL: + ast = k.arg.ast + # create subbuffers if needed + if ast.op is Ops.BUFFER_VIEW: + base = k.src[1].buf_uop.buffer + assert isinstance(base, Buffer), "base can't be MultiBuffer" + buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize) + ubufs = tuple(s.buf_uop.buffer for s in k.src if s.op is not Ops.BIND) + bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and s.src[1].op is Ops.RANGE) + if any(isinstance(x, MultiBuffer) for x in ubufs): + assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer" + dnums = [x for x in ast.variables() if x.arg[0] == '_device_num'] + for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])): + schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0].expr:i} if len(dnums) else {}, bound_ranges=bound_ranges)) + else: + # ONE -> ONE + schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata, bound_ranges=bound_ranges)) + if rk.op is Ops.END: schedule.append(rk) else: - # ONE -> ONE - schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata)) + raise RuntimeError(f"can't schedule {k.op}") for x in children[k]: in_degree[x] -= 1 if in_degree[x] == 0: queues[_heuristic(x)].append(x) - return schedule, var_vals + # expand the ranges in the schedule + real_schedule: list[ScheduleItem] = [] + sched_ptr = 0 + in_ranges = {} + range_ptrs = {} + while sched_ptr < len(schedule): + si = schedule[sched_ptr] + if isinstance(si, UOp): + if si.op is Ops.RANGE: + in_ranges[si] = 0 + range_ptrs[si] = sched_ptr + 1 + elif si.op is Ops.END: + if in_ranges[si.src[1]] < si.src[1].vmax: + in_ranges[si.src[1]] += 1 + sched_ptr = range_ptrs[si.src[1]] + continue + else: + real_schedule.append(replace(si, fixedvars=si.fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in si.bound_ranges}, bound_ranges=())) + sched_ptr += 1 + return real_schedule, var_vals diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 0919e6a30a..c77fdc16a2 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -24,8 +24,8 @@ def realize_assign(ctx:dict[UOp, None], a:UOp) -> None: pm_generate_realize_map = PatternMatcher([ # always realize SINK src (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)), - # always realize COPY/BUFFER_VIEW/CONTIGUOUS - (UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS}, name="tr"), realize), + # always realize COPY/BUFFER_VIEW/CONTIGUOUS/STORE + (UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE}, name="tr"), realize), # realize srcs of COPY, MSELECT, MSTACK (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs), # realize ASSIGN and input to assign (might be optimized out) @@ -51,21 +51,25 @@ class IndexingContext: return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.index, 0) def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp): - if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None - if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None + if x.op in {Ops.BUFFERIZE, Ops.INDEX, Ops.AFTER}: return None new_srcs = [] for s in x.src: new_src = s - if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL): + if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}: if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0]) elif s in ctx.realize_map: realized_ranges = ctx.realize_map[s] assert isinstance(realized_ranges, list), "realize map must contain range list" closed_ranges = tuple([r for i,r in enumerate(ctx.range_map[s][1]) if i in realized_ranges]) - # None in the device assigns it a number later - opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL) - new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None) - if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges]) + if s.op is Ops.STORE: + # add the ends if this is a store + new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE]) + del ctx.realize_map[s] + else: + # None in the device assigns it a number later + opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL) + new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None) + if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges]) new_srcs.append(new_src) # NOTE: do we need this? return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 883fd1452b..02878d3222 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -471,6 +471,9 @@ pm_add_range_tags = PatternMatcher([ def split_store(ctx:list[UOp], x:UOp) -> UOp|None: if len([r for r in x.ranges if r.arg[-1] != AxisType.OUTER]): return None + # ends of outer range don't go in kernels + if x.op is Ops.END and x.src[1].op is Ops.RANGE and x.src[1].arg[-1] == AxisType.OUTER: return None + # local kernel rewrite lctx = LocalAddBufferContext() ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen+pm_remove_tags, ctx=lctx, name="kernel split", bottom_up=True) @@ -504,7 +507,7 @@ def tag_uop(ctx:list[UOp], x:UOp): return x.replace(tag=(len(ctx)-1,)) add_tags = PatternMatcher([ # don't tag BUFFERs, they are global - (UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, + (UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END, Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop), (UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)), ]) diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 2c4086fc15..71944d5698 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -40,6 +40,9 @@ shared_spec = PatternMatcher([ rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \ all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)), (UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None), + + # RANGE/SPECIAL define loops, END closes them + (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True), ]) # ***** UOp spec in the Tensor graph ***** @@ -109,6 +112,10 @@ _tensor_spec = PatternMatcher([ # AFTER if things were kernelized (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True), + + # Tensor range bind / store + (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat(Ops.RANGE)), arg=None), lambda: True), + (UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True) ])+movement_ops+shared_spec tensor_spec = PatternMatcher([ @@ -128,9 +135,6 @@ shared_codegen_spec = PatternMatcher([ (UPat(Ops.AFTER, src=(UPat(GroupOp.Defines|{Ops.AFTER}),), allow_any_len=True), lambda: True), (UPat(Ops.GROUP, dtypes.void), lambda: True), - # RANGE/SPECIAL define loops, END closes them - (UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True), - # WMMA has a (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8), diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index d30496f349..8a943c84b3 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -64,6 +64,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]: # always exclude DEVICE/CONST/UNIQUE if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u) if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u) + if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u) for u in toposort: if u in excluded: continue argst = codecs.decode(str(u.arg), "unicode_escape")