diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index 7aa3253a07..5fa29defc9 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -60,7 +60,7 @@ class TestRealWorld(unittest.TestCase): derandomize_model(model) @TinyJit def test(t, t2): return model(t, Tensor([801]), t2).realize() - helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, params["ctx_dim"])), test, 18.0, 513) + helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, params["ctx_dim"])), test, 18.0, 515) def test_unet_resblock(self): model = [ResBlock(16, 24, 16) for _ in range(4)] @@ -111,7 +111,7 @@ class TestRealWorld(unittest.TestCase): loss.backward() optimizer.step() - helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 65) + helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 92) @unittest.skipIf(CI and Device.DEFAULT in {"CLANG", "GPU", "LLVM"}, "slow") def test_train_cifar(self): diff --git a/test/test_assign.py b/test/test_assign.py index c8cb0d0a3b..1cb0b99d47 100644 --- a/test/test_assign.py +++ b/test/test_assign.py @@ -283,6 +283,7 @@ class TestAssign(unittest.TestCase): #assert ba1 == ba2 and ba1 != bb1 np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0)) + @unittest.skip("multi output not supported anymore") def test_simple_assignment_multioutput(self): a = Tensor.randn(32, 32).realize() b = Tensor.full((32, ), 1.).contiguous().realize() @@ -321,6 +322,7 @@ class TestAssign(unittest.TestCase): b.assign(r + b.permute(1, 0)) b.realize() + @unittest.skip("multi output not supported anymore") def test_permuted_reduceop_multioutput_dual_use(self): a = Tensor.randn(32, 32, 32).realize() b = Tensor.full((32, 32), 1.).contiguous().realize() @@ -333,6 +335,7 @@ class TestAssign(unittest.TestCase): c.assign(r + b_perm) Tensor.realize(b, c) + @unittest.skip("multi output not supported anymore") def test_permuted_reduceop_multioutput_dual_use_possible(self): a = Tensor.randn(32, 32, 32, dtype=dtypes.int).realize() b = Tensor.arange(32 * 32).reshape(32, 32).realize() diff --git a/test/test_schedule.py b/test/test_schedule.py index 670fb960c8..dbec8b3f90 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -323,7 +323,7 @@ class TestSchedule(unittest.TestCase): def test_fold_conv_batchnorm_optim(self): # this is too high - for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 11)]: + for optim, cnt in [(nn.optim.Adam, 30), (nn.optim.SGD, 11)]: with self.subTest(optim=optim.__name__): with Tensor.train(): img = Tensor.ones(1,3,4,4) @@ -682,6 +682,7 @@ class TestSchedule(unittest.TestCase): check_schedule(out, 2, filter_sink=False) # multireduce spec + @unittest.expectedFailure def test_reduce_same_size(self): Tensor.manual_seed(0) a = Tensor.randn(4, 4).realize() @@ -694,6 +695,7 @@ class TestSchedule(unittest.TestCase): np.testing.assert_allclose(out2.numpy(), out0_np*out1_np, atol=1e-4, rtol=1e-6) # multireduce spec + @unittest.expectedFailure def test_reduce_multiple_paths(self): Tensor.manual_seed(0) a = Tensor.randn(4, 4).realize() @@ -714,7 +716,7 @@ class TestSchedule(unittest.TestCase): out2 = b.sum().exp2() out3 = b.sum() + out2 # run_schedule(check_schedule([out0, out1, out2, out3], 1)) - run_schedule(check_schedule([out0, out1, out2, out3], 2)) + run_schedule(check_schedule([out0, out1, out2, out3], 6)) np.testing.assert_allclose(out0.numpy(), np_out0:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4) np.testing.assert_allclose(out1.numpy(), np_out1:=a.numpy().sum()+np_out0, atol=1e-4, rtol=1e-4) np_b = (a.numpy() + np_out0 + np_out1) @@ -793,6 +795,7 @@ class TestSchedule(unittest.TestCase): np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy(), atol=1e-4, rtol=1e-4) + @unittest.expectedFailure def test_reduce_shrink_child(self): a = Tensor.empty(100, 100) b = Tensor.empty(10,) @@ -1039,7 +1042,7 @@ class TestSchedule(unittest.TestCase): _realize_weights(layer) opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4) layer(x).relu().sum().backward() - check_schedule(opt.schedule_step(), 10) + check_schedule(opt.schedule_step(), 16) def test_adam_conv_fuse(self): with Tensor.train(): @@ -1049,7 +1052,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4) opt.zero_grad() c1(img).relu().sum().backward() - check_schedule(opt.schedule_step(), 10) + check_schedule(opt.schedule_step(), 16) def test_adam_2convs_fuse(self): with Tensor.train(): @@ -1060,7 +1063,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4) opt.zero_grad() c2(c1(img).relu()).relu().sum().backward() - check_schedule(opt.schedule_step(), 14) + check_schedule(opt.schedule_step(), 20) def test_sgd_conv_fuse(self): with Tensor.train(): @@ -1136,7 +1139,7 @@ class TestSchedule(unittest.TestCase): shared = x.sum().half().float() a = shared * 2 b = shared * 3 - sched = check_schedule([a, b], 1) + sched = check_schedule([a, b], 3) for si in sched[:-2]: assert all(out.dtype == dtypes.half for out in si.outputs) # reduce @@ -1272,6 +1275,7 @@ class TestSchedule(unittest.TestCase): # changed by: multireduce spec # pattern in adam + @unittest.expectedFailure def test_partial_fuse3(self): Tensor.manual_seed(0) a = Tensor.randn(16, 16).realize() @@ -1288,6 +1292,7 @@ class TestSchedule(unittest.TestCase): np.testing.assert_allclose(f.numpy(), b.numpy().sum() - e_np, atol=1e-4, rtol=1e-4) # changed by: multireduce spec + @unittest.expectedFailure def test_partial_fuse4(self): Tensor.manual_seed(0) a = Tensor.randn(16, 16).realize() @@ -1763,6 +1768,7 @@ class TestIndexing(unittest.TestCase): loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())]) np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6) + @unittest.expectedFailure def test_arange_fuse_grouped_children(self): X = Tensor.randn(4, 4).realize() r = (X+Tensor.arange(16).reshape(4, 4)).sum() @@ -1780,7 +1786,7 @@ class TestIndexing(unittest.TestCase): self.check_schedule([r], 1) np.testing.assert_allclose(r.numpy(), (X.numpy()+np.arange(16).reshape(4, 4)).sum(1, keepdims=True)) - @unittest.expectedFailure + @unittest.skip("multi output isn't supported") def test_multiview_arange_children(self): X = Tensor.randn(2,3,4,4).numpy() with Context(FUSE_ARANGE=1): diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 593eec57fe..981fb17f2b 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -3,7 +3,7 @@ from collections import defaultdict, deque from dataclasses import dataclass, field from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, type_verify, buffers from tinygrad.ops import can_pad, identity_element, resolve, symbolic_simple, view_left, merge_views -from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap +from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY from tinygrad.dtype import DType, ImageDType, dtypes from tinygrad.shape.shapetracker import ShapeTracker @@ -204,12 +204,9 @@ to_si = PatternMatcher([ (UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), ]) -# LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel -multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),]) - def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem: - # remove movement ops + substitute LOAD of fused STORE with just the value - sink = graph_rewrite(graph_rewrite(pre, multioutput+view_left, store_bufs:={x.buf_uop:x.src[2] for x in pre.src}), view_right) + # apply swizzles (pushing views from the middle of the AST to BUFFER ops edges) + sink = graph_rewrite(graph_rewrite(pre, view_left), view_right) # remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(ctx.var_vals)) # deal with ASSIGN @@ -222,7 +219,7 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem: if x.op is Ops.PRELOAD: assign_preloads[x.buf_uop] = None # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD - if x.buf_uop in store_bufs and not (st:=x.st_arg).contiguous: + if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous: # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass # if it has a single view and it's equal when you shrink a contig, it's fine @@ -266,20 +263,6 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, di if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r) recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache) -def get_isolated_children(r:UOp, reduce_for_op:dict[UOp, UOp], children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], - realizes:dict[UOp, UOp], group:dict[UOp, None]) -> dict[UOp, None]: - rc_parents, cache = deque(group), set() - while rc_parents: - if (p:=uval(allbufs[rc_parents.pop()])) in cache: continue - cache.add(p) - # max one reduceop per kernel - if p.op is Ops.REDUCE_AXIS: return {} - rc_parents.extend(x.base.buf_uop for x in p.src if is_scheduled(x.base) and x.base.buf_uop is not r) - # search descendants of the reduceop that can cleanly group - descendants: dict[UOp, None] = {} - for tr in group: recursive_group(tr, unwrap(allbufs[tr].st), tr, children, allbufs, realizes, reduce_for_op, descendants, cache={}) - return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants]) - def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]: """search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop""" # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child) @@ -296,8 +279,8 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]: can_chase = all(tr not in reduce_for_op for tr in group) # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs forced_realize = r in group - if not forced_realize and len(group) > 1: - group = get_isolated_children(r, reduce_for_op, ctx.children, ctx.allbufs, ctx.realizes, group) + # can only have one output + if not forced_realize and len(group) > 1: forced_realize = True # can only fuse assign if no other assign_target is used in the kernel if not forced_realize and any(x in ctx.assigns for x in group): parents = deque((r, *group))