diff --git a/test/backend/test_outerworld.py b/test/backend/test_outerworld.py
deleted file mode 100644
index 266e9b1112..0000000000
--- a/test/backend/test_outerworld.py
+++ /dev/null
@@ -1,230 +0,0 @@
-import unittest
-import numpy as np
-from tinygrad import Tensor, UOp, nn
-from tinygrad.uop.ops import AxisType, Ops
-
-class TestOuterworldReduce(unittest.TestCase):
-  def test_reduce(self):
-    x = Tensor.ones(5, 5).contiguous()
-    a = UOp.range(5, -1, AxisType.REDUCE)
-    out = x[a]
-    # TODO: syntax for this
-    t = Tensor(UOp(Ops.REDUCE, dtype=out.uop.dtype, src=(out.uop, a), arg=Ops.ADD))
-    self.assertListEqual(t.tolist(), [5.,5.,5.,5.,5.])
-
-# TODO: delete test_outerworld_range?
-class TestOuterRange(unittest.TestCase):
-  def test_simple_range(self):
-    a = Tensor.ones(10).contiguous()
-    acc = Tensor.zeros().contiguous()
-    Tensor.realize(a, acc)
-
-    # this is fold
-    i = UOp.range(10, -100, AxisType.OUTER)
-    acc_i = acc.uop.after(i)
-    vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
-    out = Tensor(acc.uop.after(acc_i.store(acc_i + a[vi].uop).end(i)))
-    out.realize()
-    assert out.item() == 10.0
-
-  def test_inner_range(self):
-    a = Tensor.ones(10, 10).contiguous()
-    acc = Tensor.zeros(10).contiguous()
-    Tensor.realize(a, acc)
-
-    # this is fold
-    i = UOp.range(10, -100, AxisType.OUTER)
-    acc_i = acc.uop.after(i)
-    vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
-    out = Tensor(acc.uop.after(acc_i.store(acc_i + a[:, vi].uop).end(i)))
-    out.realize()
-    self.assertEqual(out.tolist(), [10.0]*10)
-
-  def test_range_matmul(self):
-    vec = Tensor.randn(1, 10).realize()
-    mats = Tensor.randn(3, 10, 10).realize()
-
-    # 3 matmuls in "scan"
-    ref = ((vec @ mats[0]) @ mats[1]) @ mats[2]
-    ref.realize()
-
-    # 3 matmuls with outer world range
-    i = UOp.range(3, -100, AxisType.OUTER)
-    vec_i = Tensor(vec.uop.after(i))
-    comp = vec_i.contiguous() @ mats[i]
-    store = vec_i.uop.store(comp.uop).end(i)
-    out = Tensor(vec.uop.after(store))
-    out.realize()
-
-    # TODO: testing allclose
-    assert Tensor.allclose(ref, out, atol=1e-5), f"max diff {(ref-out).abs().max().item()}"
-
-class TestOuterScan(unittest.TestCase):
-  def _test_scan(self):
-    vec = Tensor.randn(1, 10).realize()
-    mats = Tensor.randn(3, 10, 10).realize()
-
-    # 3 matmuls in "scan"
-    vec1 = vec @ mats[0]
-    vec2 = vec1 @ mats[1]
-    vec3 = vec2 @ mats[2]
-    ref = Tensor.stack(vec1, vec2, vec3)
-    ref.realize()
-    return vec, mats, ref
-
-  def test_uop_scan_matmul(self):
-    vec, mats, ref = self._test_scan()
-
-    # 3 matmuls with SCAN
-    i = UOp.range(3, -100, AxisType.OUTER)
-    out = Tensor.empty(3, 1, 10)
-    phi = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop))
-    comp = phi @ mats[i]
-    store = out[i].uop.store(comp.uop).end(i)
-    out = Tensor(out.uop.after(store))
-    out.realize()
-
-    # TODO: testing allclose
-    assert Tensor.allclose(ref, out, atol=1e-5), f"max diff {(ref-out).abs().max().item()}"
-
-class TestOuterworld(unittest.TestCase):
-  def test_range_plus_1(self):
-    t = Tensor.arange(100).reshape(10,10).realize()
-
-    # passthrough ranges
-    a = UOp.range(10, -1)
-    sel = t[a] + 1
-    assert sel.shape == (10,)
-    cpy = sel.reshape(1, 10).expand(a, 10).contiguous().realize()
-
-    self.assertTrue((t+1==cpy).all().item())
-
-  def test_range_plus_1_transpose(self):
-    t = Tensor.arange(100).reshape(10,10).realize()
-
-    # passthrough ranges
-    a = UOp.range(10, -1)
-    sel = t[a] + 1
-    assert sel.shape == (10,)
-    cpy = sel.reshape(10, 1).expand(10, a).contiguous().realize()
-
-    self.assertTrue(((t+1).T==cpy).all().item())
-
-  def test_flip_range(self):
-    t = Tensor.rand(10, 10).realize()
-
-    # passthrough ranges
-    a = UOp.range(10, -1)
-    sel = t[9-a]
-    cpy = sel.reshape(1, 10).expand(a, 10).contiguous().realize()
-
-    self.assertTrue((t.flip(0)==cpy).all().item())
-
-  def test_vmap(self):
-    def f(x): return x.sum(axis=0)*2
-
-    x = Tensor.ones(3, 10, 2).contiguous()
-
-    # vmap across axis 0
-    a = UOp.range(3, -1)
-    out = f(x[a])
-    out = out.reshape(1, 2).expand(a, 2).contiguous()
-
-    # 3x2 grid of 20
-    out.realize()
-    self.assertTrue((out==20).all().item())
-
-  def test_fancy_vmap(self):
-    def f(x,y): return x+y
-
-    x = Tensor.arange(9).reshape(3,3).contiguous()
-    y = Tensor.arange(9).reshape(3,3).contiguous()
-
-    a = UOp.range(3, -1)
-    out = f(x[:,a], y[a,:])
-    # TODO: this should support flatten
-    out = out.reshape(1, 3).expand(a, 3).contiguous().realize()
-    self.assertListEqual([[0,4,8],[4,8,12],[8,12,16]], out.tolist())
-
-class TestVmap(unittest.TestCase):
-  def test_vmap_inner(self, axis_type=AxisType.LOOP, fuse=False, grad=False):
-    x = Tensor.ones(1, 10).contiguous().requires_grad_()
-    mats = Tensor.ones(3, 10, 10).contiguous().requires_grad_()
-
-    ref = x @ mats
-    if fuse: ref = ref * 2
-
-    # vmap across axis 0
-    a = UOp.range(3, -1, axis_type)
-    out = x @ mats[a]
-    out = out.reshape(1, 10).pad(((a,(3-a)-1), None))
-    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
-    if fuse: out = out * 2
-    if grad:
-      out.mean().backward()
-      np.testing.assert_allclose(mats.grad.numpy(), (2./30) if fuse else (1./30))
-    out.realize()
-
-    # TODO: testing allclose
-    assert Tensor.allclose(ref, out, atol=1e-6), f"max diff {(ref-out).abs().max().item()}"
-  def test_vmap_inner_fuse(self): self.test_vmap_inner(fuse=True)
-  def test_vmap_outer(self): self.test_vmap_inner(AxisType.OUTER)
-  def test_vmap_outer_fuse(self): self.test_vmap_inner(AxisType.OUTER, fuse=True)
-
-  def test_vmap_inner_grad(self): self.test_vmap_inner(grad=True)
-  def test_vmap_inner_fuse_grad(self): self.test_vmap_inner(fuse=True, grad=True)
-  def test_vmap_outer_grad(self): self.test_vmap_inner(AxisType.OUTER, grad=True)
-
-  def test_vmap_convs(self):
-    layers = [
-      nn.Conv2d(1, 8, 3), Tensor.relu,
-      nn.Conv2d(8, 8, 3), Tensor.relu]
-    img = Tensor.randn(4, 1, 16, 16).realize(*nn.state.get_parameters(layers))
-    a = UOp.range(4, -1, AxisType.OUTER)
-    out = img[a:a+1].sequential(layers)
-    out = out.pad(((a,(4-a)-1), None, None, None))
-    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
-    out.realize()
-    np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
-
-  def test_vmap_gemm(self):
-    layers = [
-      nn.Linear(16, 16, bias=False), Tensor.relu,
-      nn.Linear(16, 16, bias=False), Tensor.relu]
-    img = Tensor.randn(4, 16).realize(*nn.state.get_parameters(layers))
-    a = UOp.range(4, -1, AxisType.OUTER)
-    out = img[a:a+1].sequential(layers)
-    out = out.pad(((a,(4-a)-1), None))
-    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
-    out.realize()
-    np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
-
-  @unittest.skip("this is broken, we need to lower the outer reduce in the outer graph")
-  def test_vmap_gemm_grad(self):
-    layers = [
-      nn.Linear(16, 16, bias=False), Tensor.relu,
-      nn.Linear(16, 16, bias=False), Tensor.relu]
-    layer_tensors = nn.state.get_parameters(layers)
-    img = Tensor.randn(4, 16).realize(*layer_tensors)
-    for l in layer_tensors: l.requires_grad_()
-    a = UOp.range(4, -1, AxisType.OUTER)
-    out = img[a:a+1].sequential(layers)
-    out = out.pad(((a,(4-a)-1), None))
-    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
-    out.mean().backward()
-    grads = [l.grad for l in layer_tensors]
-    out.realize(*grads)
-    out_grads = [x.numpy() for x in grads]
-
-    # compute reference grads
-    for l in layer_tensors: l.grad = None
-    img.sequential(layers).mean().backward()
-    grads = [l.grad for l in layer_tensors]
-    out.realize(*grads)
-    ref_grads = [x.numpy() for x in grads]
-
-    # compare
-    for o,r in zip(out_grads, ref_grads): np.testing.assert_allclose(o, r, atol=1e-6)
-
-if __name__ == '__main__':
-  unittest.main()
\ No newline at end of file
diff --git a/test/backend/test_outerworld_range.py b/test/backend/test_outerworld_range.py
deleted file mode 100644
index cfc610cde8..0000000000
--- a/test/backend/test_outerworld_range.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import unittest
-from tinygrad import Tensor, nn, Variable, UOp
-
-# outerworld range should support three things
-# 1. full optimizer steps (test_model_bound_range)
-# 2. gradient accumulation (you want to end the range before running the optimizer)
-# 3. stacked linear layers
-
-class Model:
-  def __init__(self): self.w = nn.Linear(64, 8, bias=False)
-  def __call__(self, x:Tensor) -> Tensor: return self.w(x)
-
-def get_model_and_opt():
-  Tensor.manual_seed(1337)
-  m = Model()
-  opt = nn.optim.SGD(nn.state.get_parameters(m), lr=0.1, weight_decay=0)
-  return m, opt
-
-class TestOuterworldRange(unittest.TestCase):
-  STEPS = 5
-  BS = 20
-
-  @classmethod
-  def setUpClass(cls):
-    Tensor.manual_seed(1338)
-    # it learns to compute mean
-    cls.X = Tensor.randn(cls.STEPS, cls.BS, 64).contiguous().realize()
-    cls.Y = cls.X.reshape(cls.STEPS, cls.BS, 8, 8).mean(axis=-1).contiguous().realize()
-    cls.losses = cls._get_model_baseline()
-
-  def _compare(self, losses):
-    for i,(x,y) in enumerate(zip(self.losses, losses)):
-      self.assertAlmostEqual(x, y, places=5, msg=f"mismatch at {i} in {self.losses} vs {losses}")
-
-  @classmethod
-  @Tensor.train()
-  def _get_model_baseline(self):
-    m, opt = get_model_and_opt()
-    losses = []
-    for i in range(self.STEPS):
-      opt.zero_grad()
-      loss = (m(self.X[i]) - self.Y[i]).square().mean()
-      loss.backward()
-      loss.realize(*opt.schedule_step())
-      losses.append(loss.item())
-    return losses
-
-  @Tensor.train()
-  def test_model_grad_acc(self):
-    m, opt = get_model_and_opt()
-    losses = []
-    for i in range(self.STEPS):
-      opt.zero_grad()
-      sub_batch_size = self.BS//2
-      loss = 0
-      scaling_factor = self.BS//sub_batch_size
-      for j in range(0, self.BS, sub_batch_size):
-        sub_loss = (m(self.X[i][j:j+sub_batch_size]) - self.Y[i][j:j+sub_batch_size]).square().mean() / scaling_factor
-        sub_loss.backward()
-        loss += sub_loss
-      loss.realize(*opt.schedule_step())
-      losses.append(loss.item())
-    self._compare(losses)
-
-  @Tensor.train()
-  def test_model_variable(self):
-    m, opt = get_model_and_opt()
-    losses = []
-    vi = Variable('i', 0, self.STEPS-1)
-    for i in range(self.STEPS):
-      vib = vi.bind(i)
-      opt.zero_grad()
-      loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
-      loss.backward()
-      loss.realize(*opt.schedule_step())
-      losses.append(loss.item())
-    self._compare(losses)
-
-  @Tensor.train()
-  def test_model_scheduled(self):
-    m, opt = get_model_and_opt()
-    losses = []
-    for i in range(self.STEPS):
-      opt.zero_grad()
-      loss = (m(self.X[i]) - self.Y[i]).square().mean()
-      loss.backward()
-      opt.schedule_step()
-      losses.append(loss)
-    self._compare(Tensor.stack(*losses).tolist())
-
-  @Tensor.train()
-  def test_model_scheduled_setitem(self):
-    m, opt = get_model_and_opt()
-    losses = Tensor.empty(self.STEPS)
-    for i in range(self.STEPS):
-      opt.zero_grad()
-      loss = (m(self.X[i]) - self.Y[i]).square().mean()
-      loss.backward()
-      opt.schedule_step()
-      # TODO: this shouldn't realize
-      losses[i] = loss.requires_grad_(False)
-    self._compare(losses.tolist())
-
-  @unittest.expectedFailure
-  @Tensor.train()
-  def test_model_scheduled_variable(self):
-    m, opt = get_model_and_opt()
-    losses = []
-    vi = Variable('i', 0, self.STEPS-1)
-    for i in range(self.STEPS):
-      vib = vi.bind(i)
-      opt.zero_grad()
-      loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
-      loss.backward()
-      opt.schedule_step()
-      losses.append(loss)
-    self._compare(Tensor.stack(*losses).tolist())
-
-  @unittest.expectedFailure
-  @Tensor.train()
-  def test_model_scheduled_variable_setitem(self):
-    m, opt = get_model_and_opt()
-    losses = Tensor.empty(self.STEPS)
-    vi = Variable('i', 0, self.STEPS-1)
-    for i in range(self.STEPS):
-      vib = vi.bind(i)
-      opt.zero_grad()
-      loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
-      loss.backward()
-      opt.schedule_step()
-      losses[vib] = loss.requires_grad_(False)
-    self._compare(losses.tolist())
-
-  @unittest.expectedFailure
-  @Tensor.train()
-  def test_model_bound_range(self):
-    m, opt = get_model_and_opt()
-    # TODO: should ranges be unique so you don't have to pass in the -1?
-    rng = UOp.range(self.STEPS, -1)
-    vib = Variable('i', 0, self.STEPS-1).bind(rng)
-    loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
-    loss.backward()
-    losses = Tensor.empty(self.STEPS)
-    losses[vib] = loss
-    losses.realize(*opt.schedule_step())
-
-if __name__ == "__main__":
-  unittest.main()
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 8b7a7d4204..f542f80b49 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -4,14 +4,11 @@ from collections import deque
 from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, gate_kernel_sink
 from tinygrad.uop.spec import type_verify, tensor_spec
 from tinygrad.device import Buffer, MultiBuffer
-from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
+from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE
 from tinygrad.engine.realize import ExecItem
 
 # **** schedule linearizer
 
-# ScheduleItem = tuple[AST, buffer UOps, metadata, bound_ranges]
-ScheduleItem = tuple[UOp, tuple[UOp, ...], tuple[Metadata, ...], tuple[UOp, ...]]
-
 # unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op)
 def _unwrap_src(s: UOp) -> UOp:
   while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
@@ -23,9 +20,8 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
   children: dict[UOp, list[UOp]] = {}
   in_degree: dict[UOp, int] = {}
   for u in sched_sink.toposort(gate_kernel_sink):
-    if u.op is Ops.RANGE: in_degree.setdefault(u, 0)
     if u.op is not Ops.AFTER: continue
-    if (k:=u.src[1]).op is Ops.RANGE: continue # RANGEs are scheduled directly, not through dependency graph
+    k = u.src[1]
     assert k.op in {Ops.CALL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
     in_degree.setdefault(k, 0)
     if k.op is Ops.END: assert k.src[0].op is Ops.CALL, f"END src[0] should be KERNEL, not {k.src[0].op}"
@@ -50,52 +46,21 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
   with cpu_profile(TracingKey("linearize schedule")):
     queue: deque[UOp] = deque(k for k,v in in_degree.items() if v == 0)
-
-    schedule: list[UOp] = [] # RANGE, KERNEL, or END UOps
-    sched_item: dict[UOp, ScheduleItem] = {}
+    pre_schedule: list[ExecItem] = []
+    buf_uops_list: list[UOp] = []
     while len(queue):
-      k = rk = queue.popleft()
-      if k.op is Ops.END: k = k.src[0]
-      assert k.op in {Ops.RANGE, Ops.CALL}, f"unexpected op in queue: {k.op}"
-      if k.op is Ops.RANGE: schedule.append(k)
-      elif k.op is Ops.CALL:
-        ast = k.src[0]
-        buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
-        bound_ranges = tuple(s for s in k.src[1:] if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
-        sched_item[k] = (ast, buf_uops, k.arg.metadata, bound_ranges)
-        schedule.append(k)
-      if rk.op is Ops.END: schedule.append(rk)
+      rk = queue.popleft()
+      k = rk.src[0] if rk.op is Ops.END else rk
+      assert k.op is Ops.CALL, f"unexpected op in queue: {k.op}"
+      buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
+      pre_schedule.append(ExecItem(k.src[0], [], k.arg.metadata))
+      buf_uops_list.append(UOp.sink(*buf_uops))
       for x in children.get(rk, []):
         in_degree[x] -= 1
         if in_degree[x] == 0: queue.append(x)
 
-  with cpu_profile(TracingKey("unroll outer ranges")):
-    pre_schedule, buf_uops_list = unroll_outer_ranges(schedule, sched_item)
-
   return pre_schedule, UOp.sink(*buf_uops_list)
 
-def unroll_outer_ranges(schedule:list[UOp], sched_item:dict[UOp, ScheduleItem]) -> tuple[list[ExecItem], list[UOp]]:
-  pre_schedule: list[ExecItem] = []
-  buf_uops_list: list[UOp] = []
-  sched_ptr, in_ranges, range_ptrs = 0, dict[UOp, int](), dict[UOp, int]()
-  while sched_ptr < len(schedule):
-    si = schedule[sched_ptr]
-    if si.op is Ops.RANGE:
-      in_ranges[si] = 0
-      range_ptrs[si] = sched_ptr + 1
-    elif si.op is Ops.END:
-      if in_ranges[si.src[1]] < si.src[1].vmax:
-        in_ranges[si.src[1]] += 1
-        sched_ptr = range_ptrs[si.src[1]]
-        continue
-    else:
-      assert si.op is Ops.CALL, f"unexpected op in schedule: {si.op}"
-      ast, buf_uops, metadata, bound_ranges = sched_item[si]
-      fixedvars = {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges}
-      pre_schedule.append(ExecItem(ast, [], metadata, fixedvars))
-      buf_uops_list.append(UOp.sink(*buf_uops))
-    sched_ptr += 1
-  return pre_schedule, buf_uops_list
-
 from tinygrad.engine.memory import memory_planner
 from tinygrad.schedule.rangeify import get_rangeify_map
 from tinygrad.schedule.multi import get_multi_map
diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py
index a8b0a7327d..64924138ed 100644
--- a/tinygrad/gradient.py
+++ b/tinygrad/gradient.py
@@ -39,7 +39,6 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
   (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
   (UPat(Ops.REDUCE_AXIS, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg[0])),
-  (UPat(Ops.REDUCE, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg) + (None,)*(len(ret.src)-1)),
   (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
   (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
   (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),
diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py
index 444397f182..bbb21fc76f 100644
--- a/tinygrad/schedule/indexing.py
+++ b/tinygrad/schedule/indexing.py
@@ -26,8 +26,6 @@ pm_generate_realize_map = pm_gate_kernel_sink+PatternMatcher([
   (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
   # always realize
   (UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN, Ops.ENCDEC}, name="tr"), realize),
-  # always realize REDUCE on outer ranges
-  (UPat(Ops.REDUCE, name="r"), lambda ctx,r: realize(ctx, r) if any(tr.arg[-1] == AxisType.OUTER for tr in r.src[1:]) else None),
   # realize srcs of these
   (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK, Ops.ENCDEC), name="rb"), realize_srcs),
   # sometimes realize src of assign
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index a95c69f719..e9f0831aef 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace
 import itertools
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, pm_gate_kernel_sink
-from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str
+from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags
 from tinygrad.uop.symbolic import symbolic
 from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
 from tinygrad.helpers import PCONTIG, partition, get_single_element
@@ -363,18 +363,6 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
       for op, marg in reversed(assign.arg or ()): ret = ret._mop(op, marg)
     return ret
 
-  # lower outerworld reduce here
-  if x.src[0].op is Ops.REDUCE and len(x.src[0].src) == 2 and x.src[0].src[1].arg[-1] == AxisType.OUTER:
-    assert sdtype.addrspace == AddrSpace.GLOBAL
-    outer_range = x.src[0].src[1]
-    buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
-    # NOTE: this has the same number as the outer range, we need string ranges!
-    zero_range = outer_range.replace(src=(UOp.const(dtypes.index, size),), arg=outer_range.arg[:-1]+(AxisType.LOOP,))
-    buf = buf.after(buf.index(zero_range).store(0).end(zero_range))
-    bufi = buf.index(idx, dtype=sdtype)
-    do_store = bufi.store(bufi.load() + x.src[0].src[0], tag=x.tag).end(*rngs).end(outer_range)
-    return buf.after(do_store)
-
   # NOTE: the DEFINE_LOCAL needs to be disambiguated here
   if sdtype.addrspace == AddrSpace.GLOBAL:
     buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
@@ -447,9 +435,6 @@ def handle_after(ctx:LocalAddBufferContext, after:UOp):
 
 def renumber_range(ctx:LocalAddBufferContext, r:UOp):
   if r.tag != (): return None
-  if r.arg[-1] == AxisType.OUTER:
-    # for outer range, we replace with a bound variable
-    return UOp.variable("range_"+range_str(r), r.vmin, r.vmax).bind(r.replace(tag=None))
   ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=None)
   ctx.range += 1
   return ret
@@ -519,11 +504,8 @@ pm_add_range_tags = PatternMatcher([
 ])
 
 def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
-  # if we have any non-outer ranges open here, we don't split
-  if any(r.arg[-1] != AxisType.OUTER for r in x.ranges): return None
-
-  # ends of outer range don't go in kernels
-  if x.op is Ops.END and x.src[1].op is Ops.RANGE and x.src[1].arg[-1] == AxisType.OUTER: return None
+  # if we have any open ranges here, we don't split
+  if x.ranges: return None
 
   # local kernel rewrite
   lctx = LocalAddBufferContext()
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 931141d331..83925bc7b7 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -16,16 +16,15 @@ if TYPE_CHECKING:
 class AxisType(Enum):
   def __repr__(self): return str(self)
   GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
-  THREAD = auto(); OUTER = auto(); PLACEHOLDER = auto() # noqa: E702
+  THREAD = auto(); PLACEHOLDER = auto() # noqa: E702
 
 axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
-                AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: "O"}
+                AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
 axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
-               AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta",
-               AxisType.OUTER: "green"}
+               AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
 
 # NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
 axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
-               AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5, AxisType.OUTER: -2}
+               AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
 
 range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1}
diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py
index 1f1556bc3e..dc7122ab71 100644
--- a/tinygrad/uop/spec.py
+++ b/tinygrad/uop/spec.py
@@ -123,16 +123,9 @@ _tensor_spec = PatternMatcher([
   # REDUCE_AXIS is the reduce in the tensor graph
   (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
 
-  # REDUCE with an outerworld range
-  (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
-
   # AFTER if things were kernelized
   (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
 
-  # Tensor range bind / store
-  (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat(Ops.RANGE)), arg=None), lambda: True),
-  (UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True),
-
   # allow CALL/PARAM
   (UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),
   (UPat(Ops.PARAM), lambda: True),