Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 23:18:04 -05:00
delete multi output support (#8822)
* delete multioutput for now
* test_schedule
* test_assign too
* linter
* 515 for sd
* update tests and ctx
* update that assign check
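For context, "multi output" here means one scheduled kernel writing several output buffers at once. Below is a minimal, hedged sketch of the user-facing pattern this touches, using only public Tensor calls that also appear in the tests further down; the kernel-split claim is the expectation after this commit, not something the snippet asserts.

```python
# Hedged sketch (not part of the commit). Realizing several outputs together still
# works; with multi output deleted, each output is expected to get its own kernel
# instead of sharing one kernel with two stores.
from tinygrad import Tensor

a = Tensor.randn(32, 32).realize()
r = a.sum(axis=1)            # shared parent of both outputs
b = (r + 1).contiguous()
c = (r * 2).contiguous()
Tensor.realize(b, c)         # same API as before, just more kernels scheduled
```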
@@ -60,7 +60,7 @@ class TestRealWorld(unittest.TestCase):
 derandomize_model(model)
 @TinyJit
 def test(t, t2): return model(t, Tensor([801]), t2).realize()
-helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, params["ctx_dim"])), test, 18.0, 513)
+helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, params["ctx_dim"])), test, 18.0, 515)

 def test_unet_resblock(self):
 model = [ResBlock(16, 24, 16) for _ in range(4)]
@@ -111,7 +111,7 @@ class TestRealWorld(unittest.TestCase):
 loss.backward()
 optimizer.step()

-helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 65)
+helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 92)

 @unittest.skipIf(CI and Device.DEFAULT in {"CLANG", "GPU", "LLVM"}, "slow")
 def test_train_cifar(self):
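The only change in these two hunks is the kernel-count budget passed to helper_test: 513 to 515 for the stable diffusion step and 65 to 92 for MNIST training, since work that used to share multi-output kernels is now scheduled as more kernels. A hedged way to look at such a count outside the test harness (helper_test's own bookkeeping is not shown in this diff; the Tensor.schedule call is an assumption about the public API):

```python
# Hedged sketch: count how many kernels a computation would schedule.
from tinygrad import Tensor

x = Tensor.randn(4, 4)
y = Tensor.randn(4, 4)
out = (x @ y).relu().sum()
# schedule() returns the list of ScheduleItems that realize() would run; its length
# is the kind of number these budgets (513 -> 515, 65 -> 92) keep in check.
print(len(out.schedule()))
```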
@@ -283,6 +283,7 @@ class TestAssign(unittest.TestCase):
 #assert ba1 == ba2 and ba1 != bb1
 np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))

+@unittest.skip("multi output not supported anymore")
 def test_simple_assignment_multioutput(self):
 a = Tensor.randn(32, 32).realize()
 b = Tensor.full((32, ), 1.).contiguous().realize()
@@ -321,6 +322,7 @@ class TestAssign(unittest.TestCase):
 b.assign(r + b.permute(1, 0))
 b.realize()

+@unittest.skip("multi output not supported anymore")
 def test_permuted_reduceop_multioutput_dual_use(self):
 a = Tensor.randn(32, 32, 32).realize()
 b = Tensor.full((32, 32), 1.).contiguous().realize()
@@ -333,6 +335,7 @@ class TestAssign(unittest.TestCase):
 c.assign(r + b_perm)
 Tensor.realize(b, c)

+@unittest.skip("multi output not supported anymore")
 def test_permuted_reduceop_multioutput_dual_use_possible(self):
 a = Tensor.randn(32, 32, 32, dtype=dtypes.int).realize()
 b = Tensor.arange(32 * 32).reshape(32, 32).realize()
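These three TestAssign cases are skipped rather than deleted; they exercised realizing several assign targets in one step. A hedged sketch of the general shape of such a case (only the two setup lines come from the diff; the assigns are illustrative, not the original test bodies):

```python
# Hedged sketch: two buffers assigned and realized together. Before this commit the
# scheduler could emit a single kernel containing both stores; now each assign is
# expected to land in its own kernel, hence the skips.
from tinygrad import Tensor

a = Tensor.randn(32, 32).realize()
b = Tensor.full((32,), 1.).contiguous().realize()
a.assign(a + 1)
b.assign(b * 2)
Tensor.realize(a, b)
```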
@@ -323,7 +323,7 @@ class TestSchedule(unittest.TestCase):

 def test_fold_conv_batchnorm_optim(self):
 # this is too high
-for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 11)]:
+for optim, cnt in [(nn.optim.Adam, 30), (nn.optim.SGD, 11)]:
 with self.subTest(optim=optim.__name__):
 with Tensor.train():
 img = Tensor.ones(1,3,4,4)
@@ -682,6 +682,7 @@ class TestSchedule(unittest.TestCase):
 check_schedule(out, 2, filter_sink=False)

 # multireduce spec
+@unittest.expectedFailure
 def test_reduce_same_size(self):
 Tensor.manual_seed(0)
 a = Tensor.randn(4, 4).realize()
@@ -694,6 +695,7 @@ class TestSchedule(unittest.TestCase):
 np.testing.assert_allclose(out2.numpy(), out0_np*out1_np, atol=1e-4, rtol=1e-6)

 # multireduce spec
+@unittest.expectedFailure
 def test_reduce_multiple_paths(self):
 Tensor.manual_seed(0)
 a = Tensor.randn(4, 4).realize()
@@ -714,7 +716,7 @@ class TestSchedule(unittest.TestCase):
 out2 = b.sum().exp2()
 out3 = b.sum() + out2
 # run_schedule(check_schedule([out0, out1, out2, out3], 1))
-run_schedule(check_schedule([out0, out1, out2, out3], 2))
+run_schedule(check_schedule([out0, out1, out2, out3], 6))
 np.testing.assert_allclose(out0.numpy(), np_out0:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
 np.testing.assert_allclose(out1.numpy(), np_out1:=a.numpy().sum()+np_out0, atol=1e-4, rtol=1e-4)
 np_b = (a.numpy() + np_out0 + np_out1)
@@ -793,6 +795,7 @@ class TestSchedule(unittest.TestCase):
 np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
 np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy(), atol=1e-4, rtol=1e-4)

+@unittest.expectedFailure
 def test_reduce_shrink_child(self):
 a = Tensor.empty(100, 100)
 b = Tensor.empty(10,)
@@ -1039,7 +1042,7 @@ class TestSchedule(unittest.TestCase):
 _realize_weights(layer)
 opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
 layer(x).relu().sum().backward()
-check_schedule(opt.schedule_step(), 10)
+check_schedule(opt.schedule_step(), 16)

 def test_adam_conv_fuse(self):
 with Tensor.train():
@@ -1049,7 +1052,7 @@ class TestSchedule(unittest.TestCase):
 opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
 opt.zero_grad()
 c1(img).relu().sum().backward()
-check_schedule(opt.schedule_step(), 10)
+check_schedule(opt.schedule_step(), 16)

 def test_adam_2convs_fuse(self):
 with Tensor.train():
@@ -1060,7 +1063,7 @@ class TestSchedule(unittest.TestCase):
 opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
 opt.zero_grad()
 c2(c1(img).relu()).relu().sum().backward()
-check_schedule(opt.schedule_step(), 14)
+check_schedule(opt.schedule_step(), 20)

 def test_sgd_conv_fuse(self):
 with Tensor.train():
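The Adam and SGD scheduling tests keep their structure; only the expected kernel counts for one optimizer step go up (10 to 16 and 14 to 20), since per-parameter updates are no longer packed into shared multi-output kernels. A hedged, self-contained way to look at such a count outside check_schedule (the layer size and the Tensor.schedule call are illustrative assumptions, not taken from the tests):

```python
# Hedged sketch: how many kernels one Adam step schedules for a small conv layer.
from tinygrad import Tensor
from tinygrad.nn import Conv2d
from tinygrad.nn.optim import Adam
from tinygrad.nn.state import get_parameters

with Tensor.train():
  c1 = Conv2d(3, 16, 3)
  opt = Adam(get_parameters(c1), lr=1e-4)
  opt.zero_grad()
  c1(Tensor.ones(1, 3, 8, 8)).relu().sum().backward()
  # schedule_step() returns the tensors one optimizer step would realize; scheduling
  # them shows the kernel count these tests now expect to be higher.
  print(len(Tensor.schedule(*opt.schedule_step())))
```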
@@ -1136,7 +1139,7 @@ class TestSchedule(unittest.TestCase):
 shared = x.sum().half().float()
 a = shared * 2
 b = shared * 3
-sched = check_schedule([a, b], 1)
+sched = check_schedule([a, b], 3)
 for si in sched[:-2]: assert all(out.dtype == dtypes.half for out in si.outputs)

 # reduce
@@ -1272,6 +1275,7 @@ class TestSchedule(unittest.TestCase):

 # changed by: multireduce spec
 # pattern in adam
+@unittest.expectedFailure
 def test_partial_fuse3(self):
 Tensor.manual_seed(0)
 a = Tensor.randn(16, 16).realize()
@@ -1288,6 +1292,7 @@ class TestSchedule(unittest.TestCase):
 np.testing.assert_allclose(f.numpy(), b.numpy().sum() - e_np, atol=1e-4, rtol=1e-4)

 # changed by: multireduce spec
+@unittest.expectedFailure
 def test_partial_fuse4(self):
 Tensor.manual_seed(0)
 a = Tensor.randn(16, 16).realize()
@@ -1763,6 +1768,7 @@ class TestIndexing(unittest.TestCase):
 loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())])
 np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6)

+@unittest.expectedFailure
 def test_arange_fuse_grouped_children(self):
 X = Tensor.randn(4, 4).realize()
 r = (X+Tensor.arange(16).reshape(4, 4)).sum()
@@ -1780,7 +1786,7 @@ class TestIndexing(unittest.TestCase):
 self.check_schedule([r], 1)
 np.testing.assert_allclose(r.numpy(), (X.numpy()+np.arange(16).reshape(4, 4)).sum(1, keepdims=True))

-@unittest.expectedFailure
+@unittest.skip("multi output isn't supported")
 def test_multiview_arange_children(self):
 X = Tensor.randn(2,3,4,4).numpy()
 with Context(FUSE_ARANGE=1):
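The TestIndexing hunks mark the arange-fusion multi-output cases as expected failures or skip them outright. For orientation, a hedged sketch of the FUSE_ARANGE pattern these tests are built around (shapes and index values are illustrative):

```python
# Hedged sketch: fancy indexing realized under FUSE_ARANGE, the context setting
# these TestIndexing cases run in.
from tinygrad import Tensor
from tinygrad.helpers import Context

X = Tensor.randn(4, 4).realize()
idx = Tensor([0, 2]).realize()
with Context(FUSE_ARANGE=1):
  out = X[idx].sum()
  out.realize()
print(out.numpy())
```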
@@ -3,7 +3,7 @@ from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, type_verify, buffers
 from tinygrad.ops import can_pad, identity_element, resolve, symbolic_simple, view_left, merge_views
-from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
+from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
 from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY
 from tinygrad.dtype import DType, ImageDType, dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -204,12 +204,9 @@ to_si = PatternMatcher([
 (UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
 ])

-# LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel
-multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])
-
 def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
-# remove movement ops + substitute LOAD of fused STORE with just the value
-sink = graph_rewrite(graph_rewrite(pre, multioutput+view_left, store_bufs:={x.buf_uop:x.src[2] for x in pre.src}), view_right)
+# apply swizzles (pushing views from the middle of the AST to BUFFER ops edges)
+sink = graph_rewrite(graph_rewrite(pre, view_left), view_right)
 # remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL
 ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(ctx.var_vals))
 # deal with ASSIGN
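The deleted multioutput matcher is what made a fused multi-output kernel self-consistent: a LOAD of a buffer that the same kernel STOREs was rewritten into the stored value. A hedged, tinygrad-free sketch of that rewrite on a toy tuple AST (the node encoding is invented purely for illustration):

```python
# Toy AST: ("LOAD", buf) | ("CONST", c) | (op, *children); buffers are plain strings.
def inline_same_kernel_loads(node, store_vals):
  """Replace LOAD(buf) with the value STOREd to buf in the same kernel."""
  op = node[0]
  if op == "LOAD" and node[1] in store_vals:
    return inline_same_kernel_loads(store_vals[node[1]], store_vals)
  if op == "CONST": return node
  return (op,) + tuple(inline_same_kernel_loads(s, store_vals) if isinstance(s, tuple) else s
                       for s in node[1:])

# One "kernel" with two stores: b0 <- 1 + 2 and b1 <- LOAD(b0) * 2.
store_vals = {"b0": ("ADD", ("CONST", 1), ("CONST", 2)),
              "b1": ("MUL", ("LOAD", "b0"), ("CONST", 2))}
print(inline_same_kernel_loads(("STORE", "b1", store_vals["b1"]), store_vals))
# -> ('STORE', 'b1', ('MUL', ('ADD', ('CONST', 1), ('CONST', 2)), ('CONST', 2)))
```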
@@ -222,7 +219,7 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
 if x.op is Ops.PRELOAD:
 assign_preloads[x.buf_uop] = None
 # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
-if x.buf_uop in store_bufs and not (st:=x.st_arg).contiguous:
+if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous:
 # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
 if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
 # if it has a single view and it's equal when you shrink a contig, it's fine
@@ -266,20 +263,6 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, di
 if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
 recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)

-def get_isolated_children(r:UOp, reduce_for_op:dict[UOp, UOp], children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp],
-                          realizes:dict[UOp, UOp], group:dict[UOp, None]) -> dict[UOp, None]:
-  rc_parents, cache = deque(group), set()
-  while rc_parents:
-    if (p:=uval(allbufs[rc_parents.pop()])) in cache: continue
-    cache.add(p)
-    # max one reduceop per kernel
-    if p.op is Ops.REDUCE_AXIS: return {}
-    rc_parents.extend(x.base.buf_uop for x in p.src if is_scheduled(x.base) and x.base.buf_uop is not r)
-  # search descendants of the reduceop that can cleanly group
-  descendants: dict[UOp, None] = {}
-  for tr in group: recursive_group(tr, unwrap(allbufs[tr].st), tr, children, allbufs, realizes, reduce_for_op, descendants, cache={})
-  return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
-
 def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
 """search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
 # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
@@ -296,8 +279,8 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
 can_chase = all(tr not in reduce_for_op for tr in group)
 # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
 forced_realize = r in group
-if not forced_realize and len(group) > 1:
-  group = get_isolated_children(r, reduce_for_op, ctx.children, ctx.allbufs, ctx.realizes, group)
+# can only have one output
+if not forced_realize and len(group) > 1: forced_realize = True
 # can only fuse assign if no other assign_target is used in the kernel
 if not forced_realize and any(x in ctx.assigns for x in group):
 parents = deque((r, *group))
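With get_isolated_children gone, grouping no longer tries to keep several candidate outputs together: any group larger than one is simply force-realized. A hedged paraphrase of that rule as a standalone function (names are illustrative, not the scheduler's API):

```python
# Hedged paraphrase of the simplified rule in group_realizes: a reduce's group can
# now produce at most one realized output.
def forced_realize_for(r, group: dict) -> bool:
  if r in group: return True      # the reduce feeds back into its own group
  return len(group) > 1           # "can only have one output"

print(forced_realize_for("r0", {"out_a": None}))                 # False: fuses
print(forced_realize_for("r0", {"out_a": None, "out_b": None}))  # True: would need multi output
```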