remove all the outerworld stuff, it was too complex (#14852)

This commit is contained in:
George Hotz
2026-02-18 17:44:11 +08:00
committed by GitHub
parent 6d301ad2c4
commit af839b2bd1
8 changed files with 17 additions and 459 deletions

View File

@@ -1,230 +0,0 @@
import unittest
import numpy as np
from tinygrad import Tensor, UOp, nn
from tinygrad.uop.ops import AxisType, Ops
class TestOuterworldReduce(unittest.TestCase):
  def test_reduce(self):
    """Sum a 5x5 tensor of ones along a REDUCE-typed UOp range axis."""
    src = Tensor.ones(5, 5).contiguous()
    rng = UOp.range(5, -1, AxisType.REDUCE)
    row = src[rng]
    # TODO: syntax for this
    reduced = UOp(Ops.REDUCE, dtype=row.uop.dtype, src=(row.uop, rng), arg=Ops.ADD)
    self.assertListEqual(Tensor(reduced).tolist(), [5.0] * 5)
# TODO: delete test_outerworld_range?
class TestOuterRange(unittest.TestCase):
  """Accumulator loops written directly with OUTER UOp ranges, stores, and END."""
  def test_simple_range(self):
    """Sum ten ones into a scalar accumulator across an outer range."""
    a = Tensor.ones(10).contiguous()
    acc = Tensor.zeros().contiguous()
    Tensor.realize(a, acc)
    # this is fold
    i = UOp.range(10, -100, AxisType.OUTER)
    acc_i = acc.uop.after(i)
    # bind a variable to the range so it can be used as a tensor index
    vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
    # each iteration: acc <- acc + a[i]; end(i) closes the loop
    out = Tensor(acc.uop.after(acc_i.store(acc_i + a[vi].uop).end(i)))
    out.realize()
    assert out.item() == 10.0
  def test_inner_range(self):
    """Same fold, but the accumulator is a length-10 vector summed column by column."""
    a = Tensor.ones(10, 10).contiguous()
    acc = Tensor.zeros(10).contiguous()
    Tensor.realize(a, acc)
    # this is fold
    i = UOp.range(10, -100, AxisType.OUTER)
    acc_i = acc.uop.after(i)
    vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
    out = Tensor(acc.uop.after(acc_i.store(acc_i + a[:, vi].uop).end(i)))
    out.realize()
    self.assertEqual(out.tolist(), [10.0]*10)
  def test_range_matmul(self):
    """Chain three matmuls by storing each product back into the vector's buffer."""
    vec = Tensor.randn(1, 10).realize()
    mats = Tensor.randn(3, 10, 10).realize()
    # 3 matmuls in "scan"
    ref = ((vec @ mats[0]) @ mats[1]) @ mats[2]
    ref.realize()
    # 3 matmuls with outer world range
    i = UOp.range(3, -100, AxisType.OUTER)
    vec_i = Tensor(vec.uop.after(i))
    comp = vec_i.contiguous() @ mats[i]
    store = vec_i.uop.store(comp.uop).end(i)
    out = Tensor(vec.uop.after(store))
    out.realize()
    # TODO: testing allclose
    assert Tensor.allclose(ref, out, atol=1e-5), f"max diff {(ref-out).abs().max().item()}"
class TestOuterScan(unittest.TestCase):
  """Scan (prefix) over an outer range: each step reads the previous step's output."""
  def _test_scan(self):
    """Build inputs and the eager reference: three chained matmuls, stacked."""
    vec = Tensor.randn(1, 10).realize()
    mats = Tensor.randn(3, 10, 10).realize()
    # 3 matmuls in "scan"
    vec1 = vec @ mats[0]
    vec2 = vec1 @ mats[1]
    vec3 = vec2 @ mats[2]
    ref = Tensor.stack(vec1, vec2, vec3)
    ref.realize()
    return vec, mats, ref
  def test_uop_scan_matmul(self):
    vec, mats, ref = self._test_scan()
    # 3 matmuls with SCAN
    i = UOp.range(3, -100, AxisType.OUTER)
    out = Tensor.empty(3, 1, 10)
    # phi: iteration 0 reads the seed vec, later iterations read the previous slot of out
    phi = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop))
    comp = phi @ mats[i]
    # write this iteration's product into out[i], then close the range
    store = out[i].uop.store(comp.uop).end(i)
    out = Tensor(out.uop.after(store))
    out.realize()
    # TODO: testing allclose
    assert Tensor.allclose(ref, out, atol=1e-5), f"max diff {(ref-out).abs().max().item()}"
class TestOuterworld(unittest.TestCase):
  """Passthrough outerworld ranges used directly as tensor indices."""
  def test_range_plus_1(self):
    base = Tensor.arange(100).reshape(10,10).realize()
    # passthrough ranges
    rng = UOp.range(10, -1)
    row = base[rng] + 1
    assert row.shape == (10,)
    rebuilt = row.reshape(1, 10).expand(rng, 10).contiguous().realize()
    self.assertTrue((base+1==rebuilt).all().item())
  def test_range_plus_1_transpose(self):
    base = Tensor.arange(100).reshape(10,10).realize()
    # passthrough ranges
    rng = UOp.range(10, -1)
    row = base[rng] + 1
    assert row.shape == (10,)
    rebuilt = row.reshape(10, 1).expand(10, rng).contiguous().realize()
    self.assertTrue(((base+1).T==rebuilt).all().item())
  def test_flip_range(self):
    base = Tensor.rand(10, 10).realize()
    # passthrough ranges
    rng = UOp.range(10, -1)
    row = base[9-rng]
    rebuilt = row.reshape(1, 10).expand(rng, 10).contiguous().realize()
    self.assertTrue((base.flip(0)==rebuilt).all().item())
  def test_vmap(self):
    def doubled_colsum(x): return x.sum(axis=0)*2
    batch = Tensor.ones(3, 10, 2).contiguous()
    # vmap across axis 0
    rng = UOp.range(3, -1)
    mapped = doubled_colsum(batch[rng])
    mapped = mapped.reshape(1, 2).expand(rng, 2).contiguous()
    # 3x2 grid of 20
    mapped.realize()
    self.assertTrue((mapped==20).all().item())
  def test_fancy_vmap(self):
    def add(x, y): return x+y
    lhs = Tensor.arange(9).reshape(3,3).contiguous()
    rhs = Tensor.arange(9).reshape(3,3).contiguous()
    rng = UOp.range(3, -1)
    mixed = add(lhs[:,rng], rhs[rng,:])
    # TODO: this should support flatten
    mixed = mixed.reshape(1, 3).expand(rng, 3).contiguous().realize()
    self.assertListEqual([[0,4,8],[4,8,12],[8,12,16]], mixed.tolist())
class TestVmap(unittest.TestCase):
  """Batch a computation over a range axis (vmap), reassembling results with pad + range reduce."""
  def test_vmap_inner(self, axis_type=AxisType.LOOP, fuse=False, grad=False):
    """Core vmap check; parameterized so the *_fuse/*_outer/*_grad variants below can reuse it.

    Computes x @ mats[a] for range a, pads each slice back to its slot on a new
    length-3 axis, sums the range away, and compares against the batched matmul.
    """
    x = Tensor.ones(1, 10).contiguous().requires_grad_()
    mats = Tensor.ones(3, 10, 10).contiguous().requires_grad_()
    ref = x @ mats
    if fuse: ref = ref * 2
    # vmap across axis 0
    a = UOp.range(3, -1, axis_type)
    out = x @ mats[a]
    # place slice a at position a of a length-3 axis by padding (a, (3-a)-1)
    out = out.reshape(1, 10).pad(((a,(3-a)-1), None))
    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
    if fuse: out = out * 2
    if grad:
      out.mean().backward()
      # mean over 30 unit-input outputs -> grad 1/30 per weight (2/30 when doubled by fuse)
      np.testing.assert_allclose(mats.grad.numpy(), (2./30) if fuse else (1./30))
    out.realize()
    # TODO: testing allclose
    assert Tensor.allclose(ref, out, atol=1e-6), f"max diff {(ref-out).abs().max().item()}"
  def test_vmap_inner_fuse(self): self.test_vmap_inner(fuse=True)
  def test_vmap_outer(self): self.test_vmap_inner(AxisType.OUTER)
  def test_vmap_outer_fuse(self): self.test_vmap_inner(AxisType.OUTER, fuse=True)
  def test_vmap_inner_grad(self): self.test_vmap_inner(grad=True)
  def test_vmap_inner_fuse_grad(self): self.test_vmap_inner(fuse=True, grad=True)
  def test_vmap_outer_grad(self): self.test_vmap_inner(AxisType.OUTER, grad=True)
  def test_vmap_convs(self):
    """vmap a small conv stack over the batch dimension, one image per range step."""
    layers = [
      nn.Conv2d(1, 8, 3), Tensor.relu,
      nn.Conv2d(8, 8, 3), Tensor.relu]
    img = Tensor.randn(4, 1, 16, 16).realize(*nn.state.get_parameters(layers))
    a = UOp.range(4, -1, AxisType.OUTER)
    out = img[a:a+1].sequential(layers)
    out = out.pad(((a,(4-a)-1), None, None, None))
    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
    out.realize()
    np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
  def test_vmap_gemm(self):
    """Same as test_vmap_convs but with a two-layer linear stack."""
    layers = [
      nn.Linear(16, 16, bias=False), Tensor.relu,
      nn.Linear(16, 16, bias=False), Tensor.relu]
    img = Tensor.randn(4, 16).realize(*nn.state.get_parameters(layers))
    a = UOp.range(4, -1, AxisType.OUTER)
    out = img[a:a+1].sequential(layers)
    out = out.pad(((a,(4-a)-1), None))
    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
    out.realize()
    np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
  @unittest.skip("this is broken, we need to lower the outer reduce in the outer graph")
  def test_vmap_gemm_grad(self):
    """Gradients through a vmapped gemm, compared against eager-mode grads."""
    layers = [
      nn.Linear(16, 16, bias=False), Tensor.relu,
      nn.Linear(16, 16, bias=False), Tensor.relu]
    layer_tensors = nn.state.get_parameters(layers)
    img = Tensor.randn(4, 16).realize(*layer_tensors)
    for l in layer_tensors: l.requires_grad_()
    a = UOp.range(4, -1, AxisType.OUTER)
    out = img[a:a+1].sequential(layers)
    out = out.pad(((a,(4-a)-1), None))
    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
    out.mean().backward()
    grads = [l.grad for l in layer_tensors]
    out.realize(*grads)
    out_grads = [x.numpy() for x in grads]
    # compute reference grads
    for l in layer_tensors: l.grad = None
    img.sequential(layers).mean().backward()
    grads = [l.grad for l in layer_tensors]
    # NOTE(review): out.realize is reused here just to schedule the reference grads — confirm intended
    out.realize(*grads)
    ref_grads = [x.numpy() for x in grads]
    # compare
    for o,r in zip(out_grads, ref_grads): np.testing.assert_allclose(o, r, atol=1e-6)
if __name__ == "__main__":
  unittest.main()

View File

@@ -1,148 +0,0 @@
import unittest
from tinygrad import Tensor, nn, Variable, UOp
# outerworld range should support three things
# 1. full optimizer steps (test_model_bound_range)
# 2. gradient accumulation (you want to end the range before running the optimizer)
# 3. stacked linear layers
class Model:
  """A single 64 -> 8 linear projection without bias."""
  def __init__(self):
    self.w = nn.Linear(64, 8, bias=False)
  def __call__(self, x:Tensor) -> Tensor:
    return self.w(x)
def get_model_and_opt():
  """Return a deterministically-seeded Model and an SGD optimizer (lr=0.1, no weight decay) over its parameters."""
  Tensor.manual_seed(1337)
  model = Model()
  optimizer = nn.optim.SGD(nn.state.get_parameters(model), lr=0.1, weight_decay=0)
  return model, optimizer
class TestOuterworldRange(unittest.TestCase):
  """Run STEPS training steps several equivalent ways; each must reproduce the eager baseline losses."""
  STEPS = 5  # number of optimizer steps
  BS = 20    # batch size per step
  @classmethod
  def setUpClass(cls):
    Tensor.manual_seed(1338)
    # it learns to compute mean
    cls.X = Tensor.randn(cls.STEPS, cls.BS, 64).contiguous().realize()
    cls.Y = cls.X.reshape(cls.STEPS, cls.BS, 8, 8).mean(axis=-1).contiguous().realize()
    cls.losses = cls._get_model_baseline()
  def _compare(self, losses):
    """Assert the per-step losses match the precomputed baseline to 5 decimal places."""
    for i,(x,y) in enumerate(zip(self.losses, losses)):
      self.assertAlmostEqual(x, y, places=5, msg=f"mismatch at {i} in {self.losses} vs {losses}")
  @classmethod
  @Tensor.train()
  def _get_model_baseline(self):
    """Plain step-at-a-time training loop; its losses are the reference."""
    # NOTE: this is a classmethod, so 'self' here is actually the class
    m, opt = get_model_and_opt()
    losses = []
    for i in range(self.STEPS):
      opt.zero_grad()
      loss = (m(self.X[i]) - self.Y[i]).square().mean()
      loss.backward()
      loss.realize(*opt.schedule_step())
      losses.append(loss.item())
    return losses
  @Tensor.train()
  def test_model_grad_acc(self):
    """Split each batch into two sub-batches and accumulate gradients before stepping."""
    m, opt = get_model_and_opt()
    losses = []
    for i in range(self.STEPS):
      opt.zero_grad()
      sub_batch_size = self.BS//2
      loss = 0
      # each sub-loss is scaled down so the accumulated grad matches the full-batch grad
      scaling_factor = self.BS//sub_batch_size
      for j in range(0, self.BS, sub_batch_size):
        sub_loss = (m(self.X[i][j:j+sub_batch_size]) - self.Y[i][j:j+sub_batch_size]).square().mean() / scaling_factor
        sub_loss.backward()
        loss += sub_loss
      loss.realize(*opt.schedule_step())
      losses.append(loss.item())
    self._compare(losses)
  @Tensor.train()
  def test_model_variable(self):
    """Index the dataset through a bound Variable instead of a plain Python int."""
    m, opt = get_model_and_opt()
    losses = []
    vi = Variable('i', 0, self.STEPS-1)
    for i in range(self.STEPS):
      vib = vi.bind(i)
      opt.zero_grad()
      loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
      loss.backward()
      loss.realize(*opt.schedule_step())
      losses.append(loss.item())
    self._compare(losses)
  @Tensor.train()
  def test_model_scheduled(self):
    """Keep all step losses lazy and realize them together at the end via stack."""
    m, opt = get_model_and_opt()
    losses = []
    for i in range(self.STEPS):
      opt.zero_grad()
      loss = (m(self.X[i]) - self.Y[i]).square().mean()
      loss.backward()
      opt.schedule_step()
      losses.append(loss)
    self._compare(Tensor.stack(*losses).tolist())
  @Tensor.train()
  def test_model_scheduled_setitem(self):
    """Write each lazy loss into a preallocated tensor via setitem."""
    m, opt = get_model_and_opt()
    losses = Tensor.empty(self.STEPS)
    for i in range(self.STEPS):
      opt.zero_grad()
      loss = (m(self.X[i]) - self.Y[i]).square().mean()
      loss.backward()
      opt.schedule_step()
      # TODO: this shouldn't realize
      losses[i] = loss.requires_grad_(False)
    self._compare(losses.tolist())
  @unittest.expectedFailure
  @Tensor.train()
  def test_model_scheduled_variable(self):
    """Lazy losses combined with Variable indexing; currently expected to fail."""
    m, opt = get_model_and_opt()
    losses = []
    vi = Variable('i', 0, self.STEPS-1)
    for i in range(self.STEPS):
      vib = vi.bind(i)
      opt.zero_grad()
      loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
      loss.backward()
      opt.schedule_step()
      losses.append(loss)
    self._compare(Tensor.stack(*losses).tolist())
  @unittest.expectedFailure
  @Tensor.train()
  def test_model_scheduled_variable_setitem(self):
    """setitem at a bound Variable index; currently expected to fail."""
    m, opt = get_model_and_opt()
    losses = Tensor.empty(self.STEPS)
    vi = Variable('i', 0, self.STEPS-1)
    for i in range(self.STEPS):
      vib = vi.bind(i)
      opt.zero_grad()
      loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
      loss.backward()
      opt.schedule_step()
      losses[vib] = loss.requires_grad_(False)
    self._compare(losses.tolist())
  @unittest.expectedFailure
  @Tensor.train()
  def test_model_bound_range(self):
    """A full optimizer step expressed once over a UOp range; currently expected to fail."""
    m, opt = get_model_and_opt()
    # TODO: should ranges be unique so you don't have to pass in the -1?
    rng = UOp.range(self.STEPS, -1)
    vib = Variable('i', 0, self.STEPS-1).bind(rng)
    loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
    loss.backward()
    losses = Tensor.empty(self.STEPS)
    losses[vib] = loss
    losses.realize(*opt.schedule_step())
if __name__ == '__main__':
  unittest.main()

View File

@@ -4,14 +4,11 @@ from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, gate_kernel_sink
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE
from tinygrad.engine.realize import ExecItem
# **** schedule linearizer
# ScheduleItem = tuple[AST, buffer UOps, metadata, bound_ranges]
ScheduleItem = tuple[UOp, tuple[UOp, ...], tuple[Metadata, ...], tuple[UOp, ...]]
# unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op)
def _unwrap_src(s: UOp) -> UOp:
while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
@@ -23,9 +20,8 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
children: dict[UOp, list[UOp]] = {}
in_degree: dict[UOp, int] = {}
for u in sched_sink.toposort(gate_kernel_sink):
if u.op is Ops.RANGE: in_degree.setdefault(u, 0)
if u.op is not Ops.AFTER: continue
if (k:=u.src[1]).op is Ops.RANGE: continue # RANGEs are scheduled directly, not through dependency graph
k = u.src[1]
assert k.op in {Ops.CALL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
in_degree.setdefault(k, 0)
if k.op is Ops.END: assert k.src[0].op is Ops.CALL, f"END src[0] should be KERNEL, not {k.src[0].op}"
@@ -50,52 +46,21 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
with cpu_profile(TracingKey("linearize schedule")):
queue: deque[UOp] = deque(k for k,v in in_degree.items() if v == 0)
schedule: list[UOp] = [] # RANGE, KERNEL, or END UOps
sched_item: dict[UOp, ScheduleItem] = {}
pre_schedule: list[ExecItem] = []
buf_uops_list: list[UOp] = []
while len(queue):
k = rk = queue.popleft()
if k.op is Ops.END: k = k.src[0]
assert k.op in {Ops.RANGE, Ops.CALL}, f"unexpected op in queue: {k.op}"
if k.op is Ops.RANGE: schedule.append(k)
elif k.op is Ops.CALL:
ast = k.src[0]
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
bound_ranges = tuple(s for s in k.src[1:] if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
sched_item[k] = (ast, buf_uops, k.arg.metadata, bound_ranges)
schedule.append(k)
if rk.op is Ops.END: schedule.append(rk)
rk = queue.popleft()
k = rk.src[0] if rk.op is Ops.END else rk
assert k.op is Ops.CALL, f"unexpected op in queue: {k.op}"
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
pre_schedule.append(ExecItem(k.src[0], [], k.arg.metadata))
buf_uops_list.append(UOp.sink(*buf_uops))
for x in children.get(rk, []):
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)
with cpu_profile(TracingKey("unroll outer ranges")):
pre_schedule, buf_uops_list = unroll_outer_ranges(schedule, sched_item)
return pre_schedule, UOp.sink(*buf_uops_list)
def unroll_outer_ranges(schedule:list[UOp], sched_item:dict[UOp, ScheduleItem]) -> tuple[list[ExecItem], list[UOp]]:
  """Flatten a schedule containing RANGE/END markers into a linear list of ExecItems.

  Each RANGE...END region is replayed by jumping the schedule pointer back to
  just past the RANGE until the counter reaches the range's vmax; CALLs inside
  pick up the current iteration through fixedvars for variables bound to that range.
  """
  pre_schedule: list[ExecItem] = []
  buf_uops_list: list[UOp] = []
  # sched_ptr: cursor into schedule; in_ranges: current iteration per open RANGE;
  # range_ptrs: position just after each RANGE, used as the jump-back target on END
  sched_ptr, in_ranges, range_ptrs = 0, dict[UOp, int](), dict[UOp, int]()
  while sched_ptr < len(schedule):
    si = schedule[sched_ptr]
    if si.op is Ops.RANGE:
      in_ranges[si] = 0
      range_ptrs[si] = sched_ptr + 1
    elif si.op is Ops.END:
      # not done with this range yet: bump the counter and jump back to the range body
      if in_ranges[si.src[1]] < si.src[1].vmax:
        in_ranges[si.src[1]] += 1
        sched_ptr = range_ptrs[si.src[1]]
        continue
    else:
      assert si.op is Ops.CALL, f"unexpected op in schedule: {si.op}"
      ast, buf_uops, metadata, bound_ranges = sched_item[si]
      # substitute the current loop counters for the variables bound to open ranges
      fixedvars = {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges}
      pre_schedule.append(ExecItem(ast, [], metadata, fixedvars))
      buf_uops_list.append(UOp.sink(*buf_uops))
    sched_ptr += 1
  return pre_schedule, buf_uops_list
from tinygrad.engine.memory import memory_planner
from tinygrad.schedule.rangeify import get_rangeify_map
from tinygrad.schedule.multi import get_multi_map

View File

@@ -39,7 +39,6 @@ pm_gradient = PatternMatcher([
(UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
(UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
(UPat(Ops.REDUCE_AXIS, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg[0])),
(UPat(Ops.REDUCE, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg) + (None,)*(len(ret.src)-1)),
(UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
(UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
(UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),

View File

@@ -26,8 +26,6 @@ pm_generate_realize_map = pm_gate_kernel_sink+PatternMatcher([
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
# always realize
(UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN, Ops.ENCDEC}, name="tr"), realize),
# always realize REDUCE on outer ranges
(UPat(Ops.REDUCE, name="r"), lambda ctx,r: realize(ctx, r) if any(tr.arg[-1] == AxisType.OUTER for tr in r.src[1:]) else None),
# realize srcs of these
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK, Ops.ENCDEC), name="rb"), realize_srcs),
# sometimes realize src of assign

View File

@@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, pm_gate_kernel_sink
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
from tinygrad.helpers import PCONTIG, partition, get_single_element
@@ -363,18 +363,6 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
for op, marg in reversed(assign.arg or ()): ret = ret._mop(op, marg)
return ret
# lower outerworld reduce here
if x.src[0].op is Ops.REDUCE and len(x.src[0].src) == 2 and x.src[0].src[1].arg[-1] == AxisType.OUTER:
assert sdtype.addrspace == AddrSpace.GLOBAL
outer_range = x.src[0].src[1]
buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
# NOTE: this has the same number as the outer range, we need string ranges!
zero_range = outer_range.replace(src=(UOp.const(dtypes.index, size),), arg=outer_range.arg[:-1]+(AxisType.LOOP,))
buf = buf.after(buf.index(zero_range).store(0).end(zero_range))
bufi = buf.index(idx, dtype=sdtype)
do_store = bufi.store(bufi.load() + x.src[0].src[0], tag=x.tag).end(*rngs).end(outer_range)
return buf.after(do_store)
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
if sdtype.addrspace == AddrSpace.GLOBAL:
buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
@@ -447,9 +435,6 @@ def handle_after(ctx:LocalAddBufferContext, after:UOp):
def renumber_range(ctx:LocalAddBufferContext, r:UOp):
  """Renumber a RANGE with the next sequential id from ctx; OUTER ranges become bound variables instead."""
  # a non-empty tag means this range was already processed; leave it alone
  if r.tag != (): return None
  if r.arg[-1] == AxisType.OUTER:
    # for outer range, we replace with a bound variable
    return UOp.variable("range_"+range_str(r), r.vmin, r.vmax).bind(r.replace(tag=None))
  ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=None)
  ctx.range += 1
  return ret
@@ -519,11 +504,8 @@ pm_add_range_tags = PatternMatcher([
])
def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
# if we have any non-outer ranges open here, we don't split
if any(r.arg[-1] != AxisType.OUTER for r in x.ranges): return None
# ends of outer range don't go in kernels
if x.op is Ops.END and x.src[1].op is Ops.RANGE and x.src[1].arg[-1] == AxisType.OUTER: return None
# if we have any open ranges here, we don't split
if x.ranges: return None
# local kernel rewrite
lctx = LocalAddBufferContext()

View File

@@ -16,16 +16,15 @@ if TYPE_CHECKING:
class AxisType(Enum):
def __repr__(self): return str(self)
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
THREAD = auto(); OUTER = auto(); PLACEHOLDER = auto() # noqa: E702
THREAD = auto(); PLACEHOLDER = auto() # noqa: E702
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: "O"}
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta",
AxisType.OUTER: "green"}
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5, AxisType.OUTER: -2}
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1}

View File

@@ -123,16 +123,9 @@ _tensor_spec = PatternMatcher([
# REDUCE_AXIS is the reduce in the tensor graph
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
# REDUCE with an outerworld range
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
# AFTER if things were kernelized
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
# Tensor range bind / store
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat(Ops.RANGE)), arg=None), lambda: True),
(UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True),
# allow CALL/PARAM
(UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),
(UPat(Ops.PARAM), lambda: True),