mirror of https://github.com/tinygrad/tinygrad.git
remove all the outerworld stuff, it was too complex (#14852)
@@ -1,230 +0,0 @@
import unittest
import numpy as np
from tinygrad import Tensor, UOp, nn
from tinygrad.uop.ops import AxisType, Ops

class TestOuterworldReduce(unittest.TestCase):
  def test_reduce(self):
    x = Tensor.ones(5, 5).contiguous()
    a = UOp.range(5, -1, AxisType.REDUCE)
    out = x[a]
    # TODO: syntax for this
    t = Tensor(UOp(Ops.REDUCE, dtype=out.uop.dtype, src=(out.uop, a), arg=Ops.ADD))
    self.assertListEqual(t.tolist(), [5.,5.,5.,5.,5.])

# TODO: delete test_outerworld_range?
class TestOuterRange(unittest.TestCase):
  def test_simple_range(self):
    a = Tensor.ones(10).contiguous()
    acc = Tensor.zeros().contiguous()
    Tensor.realize(a, acc)

    # this is fold
    i = UOp.range(10, -100, AxisType.OUTER)
    acc_i = acc.uop.after(i)
    vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
    out = Tensor(acc.uop.after(acc_i.store(acc_i + a[vi].uop).end(i)))
    out.realize()
    assert out.item() == 10.0

  def test_inner_range(self):
    a = Tensor.ones(10, 10).contiguous()
    acc = Tensor.zeros(10).contiguous()
    Tensor.realize(a, acc)

    # this is fold
    i = UOp.range(10, -100, AxisType.OUTER)
    acc_i = acc.uop.after(i)
    vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
    out = Tensor(acc.uop.after(acc_i.store(acc_i + a[:, vi].uop).end(i)))
    out.realize()
    self.assertEqual(out.tolist(), [10.0]*10)

  def test_range_matmul(self):
    vec = Tensor.randn(1, 10).realize()
    mats = Tensor.randn(3, 10, 10).realize()

    # 3 matmuls in "scan"
    ref = ((vec @ mats[0]) @ mats[1]) @ mats[2]
    ref.realize()

    # 3 matmuls with outer world range
    i = UOp.range(3, -100, AxisType.OUTER)
    vec_i = Tensor(vec.uop.after(i))
    comp = vec_i.contiguous() @ mats[i]
    store = vec_i.uop.store(comp.uop).end(i)
    out = Tensor(vec.uop.after(store))
    out.realize()

    # TODO: testing allclose
    assert Tensor.allclose(ref, out, atol=1e-5), f"max diff {(ref-out).abs().max().item()}"

class TestOuterScan(unittest.TestCase):
  def _test_scan(self):
    vec = Tensor.randn(1, 10).realize()
    mats = Tensor.randn(3, 10, 10).realize()

    # 3 matmuls in "scan"
    vec1 = vec @ mats[0]
    vec2 = vec1 @ mats[1]
    vec3 = vec2 @ mats[2]
    ref = Tensor.stack(vec1, vec2, vec3)
    ref.realize()
    return vec, mats, ref

  def test_uop_scan_matmul(self):
    vec, mats, ref = self._test_scan()

    # 3 matmuls with SCAN
    i = UOp.range(3, -100, AxisType.OUTER)
    out = Tensor.empty(3, 1, 10)
    phi = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop))
    comp = phi @ mats[i]
    store = out[i].uop.store(comp.uop).end(i)
    out = Tensor(out.uop.after(store))
    out.realize()

    # TODO: testing allclose
    assert Tensor.allclose(ref, out, atol=1e-5), f"max diff {(ref-out).abs().max().item()}"

class TestOuterworld(unittest.TestCase):
  def test_range_plus_1(self):
    t = Tensor.arange(100).reshape(10,10).realize()

    # passthrough ranges
    a = UOp.range(10, -1)
    sel = t[a] + 1
    assert sel.shape == (10,)
    cpy = sel.reshape(1, 10).expand(a, 10).contiguous().realize()

    self.assertTrue((t+1==cpy).all().item())

  def test_range_plus_1_transpose(self):
    t = Tensor.arange(100).reshape(10,10).realize()

    # passthrough ranges
    a = UOp.range(10, -1)
    sel = t[a] + 1
    assert sel.shape == (10,)
    cpy = sel.reshape(10, 1).expand(10, a).contiguous().realize()

    self.assertTrue(((t+1).T==cpy).all().item())

  def test_flip_range(self):
    t = Tensor.rand(10, 10).realize()

    # passthrough ranges
    a = UOp.range(10, -1)
    sel = t[9-a]
    cpy = sel.reshape(1, 10).expand(a, 10).contiguous().realize()

    self.assertTrue((t.flip(0)==cpy).all().item())

  def test_vmap(self):
    def f(x): return x.sum(axis=0)*2

    x = Tensor.ones(3, 10, 2).contiguous()

    # vmap across axis 0
    a = UOp.range(3, -1)
    out = f(x[a])
    out = out.reshape(1, 2).expand(a, 2).contiguous()

    # 3x2 grid of 20
    out.realize()
    self.assertTrue((out==20).all().item())

  def test_fancy_vmap(self):
    def f(x,y): return x+y

    x = Tensor.arange(9).reshape(3,3).contiguous()
    y = Tensor.arange(9).reshape(3,3).contiguous()

    a = UOp.range(3, -1)
    out = f(x[:,a], y[a,:])
    # TODO: this should support flatten
    out = out.reshape(1, 3).expand(a, 3).contiguous().realize()
    self.assertListEqual([[0,4,8],[4,8,12],[8,12,16]], out.tolist())

class TestVmap(unittest.TestCase):
  def test_vmap_inner(self, axis_type=AxisType.LOOP, fuse=False, grad=False):
    x = Tensor.ones(1, 10).contiguous().requires_grad_()
    mats = Tensor.ones(3, 10, 10).contiguous().requires_grad_()

    ref = x @ mats
    if fuse: ref = ref * 2

    # vmap across axis 0
    a = UOp.range(3, -1, axis_type)
    out = x @ mats[a]
    out = out.reshape(1, 10).pad(((a,(3-a)-1), None))
    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
    if fuse: out = out * 2
    if grad:
      out.mean().backward()
      np.testing.assert_allclose(mats.grad.numpy(), (2./30) if fuse else (1./30))
    out.realize()

    # TODO: testing allclose
    assert Tensor.allclose(ref, out, atol=1e-6), f"max diff {(ref-out).abs().max().item()}"
  def test_vmap_inner_fuse(self): self.test_vmap_inner(fuse=True)
  def test_vmap_outer(self): self.test_vmap_inner(AxisType.OUTER)
  def test_vmap_outer_fuse(self): self.test_vmap_inner(AxisType.OUTER, fuse=True)

  def test_vmap_inner_grad(self): self.test_vmap_inner(grad=True)
  def test_vmap_inner_fuse_grad(self): self.test_vmap_inner(fuse=True, grad=True)
  def test_vmap_outer_grad(self): self.test_vmap_inner(AxisType.OUTER, grad=True)

  def test_vmap_convs(self):
    layers = [
      nn.Conv2d(1, 8, 3), Tensor.relu,
      nn.Conv2d(8, 8, 3), Tensor.relu]
    img = Tensor.randn(4, 1, 16, 16).realize(*nn.state.get_parameters(layers))
    a = UOp.range(4, -1, AxisType.OUTER)
    out = img[a:a+1].sequential(layers)
    out = out.pad(((a,(4-a)-1), None, None, None))
    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
    out.realize()
    np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)

  def test_vmap_gemm(self):
    layers = [
      nn.Linear(16, 16, bias=False), Tensor.relu,
      nn.Linear(16, 16, bias=False), Tensor.relu]
    img = Tensor.randn(4, 16).realize(*nn.state.get_parameters(layers))
    a = UOp.range(4, -1, AxisType.OUTER)
    out = img[a:a+1].sequential(layers)
    out = out.pad(((a,(4-a)-1), None))
    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
    out.realize()
    np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)

  @unittest.skip("this is broken, we need to lower the outer reduce in the outer graph")
  def test_vmap_gemm_grad(self):
    layers = [
      nn.Linear(16, 16, bias=False), Tensor.relu,
      nn.Linear(16, 16, bias=False), Tensor.relu]
    layer_tensors = nn.state.get_parameters(layers)
    img = Tensor.randn(4, 16).realize(*layer_tensors)
    for l in layer_tensors: l.requires_grad_()
    a = UOp.range(4, -1, AxisType.OUTER)
    out = img[a:a+1].sequential(layers)
    out = out.pad(((a,(4-a)-1), None))
    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
    out.mean().backward()
    grads = [l.grad for l in layer_tensors]
    out.realize(*grads)
    out_grads = [x.numpy() for x in grads]

    # compute reference grads
    for l in layer_tensors: l.grad = None
    img.sequential(layers).mean().backward()
    grads = [l.grad for l in layer_tensors]
    out.realize(*grads)
    ref_grads = [x.numpy() for x in grads]

    # compare
    for o,r in zip(out_grads, ref_grads): np.testing.assert_allclose(o, r, atol=1e-6)

if __name__ == '__main__':
  unittest.main()
@@ -1,148 +0,0 @@
import unittest
from tinygrad import Tensor, nn, Variable, UOp

# outerworld range should support three things
# 1. full optimizer steps (test_model_bound_range)
# 2. gradient accumulation (you want to end the range before running the optimizer)
# 3. stacked linear layers

class Model:
  def __init__(self): self.w = nn.Linear(64, 8, bias=False)
  def __call__(self, x:Tensor) -> Tensor: return self.w(x)

def get_model_and_opt():
  Tensor.manual_seed(1337)
  m = Model()
  opt = nn.optim.SGD(nn.state.get_parameters(m), lr=0.1, weight_decay=0)
  return m, opt

class TestOuterworldRange(unittest.TestCase):
  STEPS = 5
  BS = 20

  @classmethod
  def setUpClass(cls):
    Tensor.manual_seed(1338)
    # it learns to compute mean
    cls.X = Tensor.randn(cls.STEPS, cls.BS, 64).contiguous().realize()
    cls.Y = cls.X.reshape(cls.STEPS, cls.BS, 8, 8).mean(axis=-1).contiguous().realize()
    cls.losses = cls._get_model_baseline()

  def _compare(self, losses):
    for i,(x,y) in enumerate(zip(self.losses, losses)):
      self.assertAlmostEqual(x, y, places=5, msg=f"mismatch at {i} in {self.losses} vs {losses}")

  @classmethod
  @Tensor.train()
  def _get_model_baseline(self):
    m, opt = get_model_and_opt()
    losses = []
    for i in range(self.STEPS):
      opt.zero_grad()
      loss = (m(self.X[i]) - self.Y[i]).square().mean()
      loss.backward()
      loss.realize(*opt.schedule_step())
      losses.append(loss.item())
    return losses

  @Tensor.train()
  def test_model_grad_acc(self):
    m, opt = get_model_and_opt()
    losses = []
    for i in range(self.STEPS):
      opt.zero_grad()
      sub_batch_size = self.BS//2
      loss = 0
      scaling_factor = self.BS//sub_batch_size
      for j in range(0, self.BS, sub_batch_size):
        sub_loss = (m(self.X[i][j:j+sub_batch_size]) - self.Y[i][j:j+sub_batch_size]).square().mean() / scaling_factor
        sub_loss.backward()
        loss += sub_loss
      loss.realize(*opt.schedule_step())
      losses.append(loss.item())
    self._compare(losses)

  @Tensor.train()
  def test_model_variable(self):
    m, opt = get_model_and_opt()
    losses = []
    vi = Variable('i', 0, self.STEPS-1)
    for i in range(self.STEPS):
      vib = vi.bind(i)
      opt.zero_grad()
      loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
      loss.backward()
      loss.realize(*opt.schedule_step())
      losses.append(loss.item())
    self._compare(losses)

  @Tensor.train()
  def test_model_scheduled(self):
    m, opt = get_model_and_opt()
    losses = []
    for i in range(self.STEPS):
      opt.zero_grad()
      loss = (m(self.X[i]) - self.Y[i]).square().mean()
      loss.backward()
      opt.schedule_step()
      losses.append(loss)
    self._compare(Tensor.stack(*losses).tolist())

  @Tensor.train()
  def test_model_scheduled_setitem(self):
    m, opt = get_model_and_opt()
    losses = Tensor.empty(self.STEPS)
    for i in range(self.STEPS):
      opt.zero_grad()
      loss = (m(self.X[i]) - self.Y[i]).square().mean()
      loss.backward()
      opt.schedule_step()
      # TODO: this shouldn't realize
      losses[i] = loss.requires_grad_(False)
    self._compare(losses.tolist())

  @unittest.expectedFailure
  @Tensor.train()
  def test_model_scheduled_variable(self):
    m, opt = get_model_and_opt()
    losses = []
    vi = Variable('i', 0, self.STEPS-1)
    for i in range(self.STEPS):
      vib = vi.bind(i)
      opt.zero_grad()
      loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
      loss.backward()
      opt.schedule_step()
      losses.append(loss)
    self._compare(Tensor.stack(*losses).tolist())

  @unittest.expectedFailure
  @Tensor.train()
  def test_model_scheduled_variable_setitem(self):
    m, opt = get_model_and_opt()
    losses = Tensor.empty(self.STEPS)
    vi = Variable('i', 0, self.STEPS-1)
    for i in range(self.STEPS):
      vib = vi.bind(i)
      opt.zero_grad()
      loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
      loss.backward()
      opt.schedule_step()
      losses[vib] = loss.requires_grad_(False)
    self._compare(losses.tolist())

  @unittest.expectedFailure
  @Tensor.train()
  def test_model_bound_range(self):
    m, opt = get_model_and_opt()
    # TODO: should ranges be unique so you don't have to pass in the -1?
    rng = UOp.range(self.STEPS, -1)
    vib = Variable('i', 0, self.STEPS-1).bind(rng)
    loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
    loss.backward()
    losses = Tensor.empty(self.STEPS)
    losses[vib] = loss
    losses.realize(*opt.schedule_step())

if __name__ == "__main__":
  unittest.main()
@@ -4,14 +4,11 @@ from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, gate_kernel_sink
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE
from tinygrad.engine.realize import ExecItem

# **** schedule linearizer

# ScheduleItem = tuple[AST, buffer UOps, metadata, bound_ranges]
ScheduleItem = tuple[UOp, tuple[UOp, ...], tuple[Metadata, ...], tuple[UOp, ...]]

# unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op)
def _unwrap_src(s: UOp) -> UOp:
  while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
@@ -23,9 +20,8 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
  children: dict[UOp, list[UOp]] = {}
  in_degree: dict[UOp, int] = {}
  for u in sched_sink.toposort(gate_kernel_sink):
    if u.op is Ops.RANGE: in_degree.setdefault(u, 0)
    if u.op is not Ops.AFTER: continue
    if (k:=u.src[1]).op is Ops.RANGE: continue # RANGEs are scheduled directly, not through dependency graph
    k = u.src[1]
    assert k.op in {Ops.CALL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
    in_degree.setdefault(k, 0)
    if k.op is Ops.END: assert k.src[0].op is Ops.CALL, f"END src[0] should be KERNEL, not {k.src[0].op}"
@@ -50,52 +46,21 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:

  with cpu_profile(TracingKey("linearize schedule")):
    queue: deque[UOp] = deque(k for k,v in in_degree.items() if v == 0)

    schedule: list[UOp] = [] # RANGE, KERNEL, or END UOps
    sched_item: dict[UOp, ScheduleItem] = {}
    pre_schedule: list[ExecItem] = []
    buf_uops_list: list[UOp] = []
    while len(queue):
      k = rk = queue.popleft()
      if k.op is Ops.END: k = k.src[0]
      assert k.op in {Ops.RANGE, Ops.CALL}, f"unexpected op in queue: {k.op}"
      if k.op is Ops.RANGE: schedule.append(k)
      elif k.op is Ops.CALL:
        ast = k.src[0]
        buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
        bound_ranges = tuple(s for s in k.src[1:] if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
        sched_item[k] = (ast, buf_uops, k.arg.metadata, bound_ranges)
        schedule.append(k)
      if rk.op is Ops.END: schedule.append(rk)
      rk = queue.popleft()
      k = rk.src[0] if rk.op is Ops.END else rk
      assert k.op is Ops.CALL, f"unexpected op in queue: {k.op}"
      buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
      pre_schedule.append(ExecItem(k.src[0], [], k.arg.metadata))
      buf_uops_list.append(UOp.sink(*buf_uops))
      for x in children.get(rk, []):
        in_degree[x] -= 1
        if in_degree[x] == 0: queue.append(x)

  with cpu_profile(TracingKey("unroll outer ranges")):
    pre_schedule, buf_uops_list = unroll_outer_ranges(schedule, sched_item)
  return pre_schedule, UOp.sink(*buf_uops_list)

def unroll_outer_ranges(schedule:list[UOp], sched_item:dict[UOp, ScheduleItem]) -> tuple[list[ExecItem], list[UOp]]:
  pre_schedule: list[ExecItem] = []
  buf_uops_list: list[UOp] = []
  sched_ptr, in_ranges, range_ptrs = 0, dict[UOp, int](), dict[UOp, int]()
  while sched_ptr < len(schedule):
    si = schedule[sched_ptr]
    if si.op is Ops.RANGE:
      in_ranges[si] = 0
      range_ptrs[si] = sched_ptr + 1
    elif si.op is Ops.END:
      if in_ranges[si.src[1]] < si.src[1].vmax:
        in_ranges[si.src[1]] += 1
        sched_ptr = range_ptrs[si.src[1]]
        continue
    else:
      assert si.op is Ops.CALL, f"unexpected op in schedule: {si.op}"
      ast, buf_uops, metadata, bound_ranges = sched_item[si]
      fixedvars = {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges}
      pre_schedule.append(ExecItem(ast, [], metadata, fixedvars))
      buf_uops_list.append(UOp.sink(*buf_uops))
    sched_ptr += 1
  return pre_schedule, buf_uops_list

from tinygrad.engine.memory import memory_planner
from tinygrad.schedule.rangeify import get_rangeify_map
from tinygrad.schedule.multi import get_multi_map
@@ -39,7 +39,6 @@ pm_gradient = PatternMatcher([
  (UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
  (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
  (UPat(Ops.REDUCE_AXIS, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg[0])),
  (UPat(Ops.REDUCE, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg) + (None,)*(len(ret.src)-1)),
  (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
  (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
  (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),
@@ -26,8 +26,6 @@ pm_generate_realize_map = pm_gate_kernel_sink+PatternMatcher([
  (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
  # always realize
  (UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN, Ops.ENCDEC}, name="tr"), realize),
  # always realize REDUCE on outer ranges
  (UPat(Ops.REDUCE, name="r"), lambda ctx,r: realize(ctx, r) if any(tr.arg[-1] == AxisType.OUTER for tr in r.src[1:]) else None),
  # realize srcs of these
  (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK, Ops.ENCDEC), name="rb"), realize_srcs),
  # sometimes realize src of assign
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, pm_gate_kernel_sink
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
from tinygrad.helpers import PCONTIG, partition, get_single_element
@@ -363,18 +363,6 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
    for op, marg in reversed(assign.arg or ()): ret = ret._mop(op, marg)
    return ret

  # lower outerworld reduce here
  if x.src[0].op is Ops.REDUCE and len(x.src[0].src) == 2 and x.src[0].src[1].arg[-1] == AxisType.OUTER:
    assert sdtype.addrspace == AddrSpace.GLOBAL
    outer_range = x.src[0].src[1]
    buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
    # NOTE: this has the same number as the outer range, we need string ranges!
    zero_range = outer_range.replace(src=(UOp.const(dtypes.index, size),), arg=outer_range.arg[:-1]+(AxisType.LOOP,))
    buf = buf.after(buf.index(zero_range).store(0).end(zero_range))
    bufi = buf.index(idx, dtype=sdtype)
    do_store = bufi.store(bufi.load() + x.src[0].src[0], tag=x.tag).end(*rngs).end(outer_range)
    return buf.after(do_store)

  # NOTE: the DEFINE_LOCAL needs to be disambiguated here
  if sdtype.addrspace == AddrSpace.GLOBAL:
    buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
@@ -447,9 +435,6 @@ def handle_after(ctx:LocalAddBufferContext, after:UOp):

def renumber_range(ctx:LocalAddBufferContext, r:UOp):
  if r.tag != (): return None
  if r.arg[-1] == AxisType.OUTER:
    # for outer range, we replace with a bound variable
    return UOp.variable("range_"+range_str(r), r.vmin, r.vmax).bind(r.replace(tag=None))
  ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=None)
  ctx.range += 1
  return ret
@@ -519,11 +504,8 @@ pm_add_range_tags = PatternMatcher([
])

def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
  # if we have any non-outer ranges open here, we don't split
  if any(r.arg[-1] != AxisType.OUTER for r in x.ranges): return None

  # ends of outer range don't go in kernels
  if x.op is Ops.END and x.src[1].op is Ops.RANGE and x.src[1].arg[-1] == AxisType.OUTER: return None
  # if we have any open ranges here, we don't split
  if x.ranges: return None

  # local kernel rewrite
  lctx = LocalAddBufferContext()
@@ -16,16 +16,15 @@ if TYPE_CHECKING:
class AxisType(Enum):
  def __repr__(self): return str(self)
  GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
  THREAD = auto(); OUTER = auto(); PLACEHOLDER = auto() # noqa: E702
  THREAD = auto(); PLACEHOLDER = auto() # noqa: E702
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
                AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: "O"}
                AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
               AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta",
               AxisType.OUTER: "green"}
               AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}

# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
               AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5, AxisType.OUTER: -2}
               AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}

range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1}
@@ -123,16 +123,9 @@ _tensor_spec = PatternMatcher([
  # REDUCE_AXIS is the reduce in the tensor graph
  (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),

  # REDUCE with an outerworld range
  (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),

  # AFTER if things were kernelized
  (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),

  # Tensor range bind / store
  (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat(Ops.RANGE)), arg=None), lambda: True),
  (UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True),

  # allow CALL/PARAM
  (UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),
  (UPat(Ops.PARAM), lambda: True),