delete multi output support (#8822)

* delete multioutput for now

* test_schedule

* test_assign too

* linter

* 515 for sd

* update tests and ctx

* update that assign check
qazal
2025-01-30 22:45:50 -05:00
committed by GitHub
parent 7647cd8428
commit 1fce864a6d
4 changed files with 24 additions and 32 deletions
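
What the deletion looks like from the frontend, as a minimal sketch modeled on test_reduce_same_size below (the Tensor.schedule call and the shapes are assumptions about the surrounding tinygrad API, not part of this diff): outputs that used to share one multi-output kernel now each get their own ScheduleItem.

# minimal sketch, assuming tinygrad's public Tensor API (Tensor.schedule);
# shapes and constants are illustrative, not taken from this commit
from tinygrad import Tensor
a = Tensor.randn(4, 4).realize()
out0 = a.sum() + 2
out1 = a.sum() + 4
# before this commit both outputs could be stored by a single multi-output
# kernel (the old check_schedule([...], 1) expectation); afterwards the
# schedule holds one kernel per realized output
sched = Tensor.schedule(out0, out1)
print(len(sched))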

View File

@@ -60,7 +60,7 @@ class TestRealWorld(unittest.TestCase):
derandomize_model(model)
@TinyJit
def test(t, t2): return model(t, Tensor([801]), t2).realize()
helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, params["ctx_dim"])), test, 18.0, 513)
helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, params["ctx_dim"])), test, 18.0, 515)
def test_unet_resblock(self):
model = [ResBlock(16, 24, 16) for _ in range(4)]
@@ -111,7 +111,7 @@ class TestRealWorld(unittest.TestCase):
loss.backward()
optimizer.step()
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 65)
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 92)
@unittest.skipIf(CI and Device.DEFAULT in {"CLANG", "GPU", "LLVM"}, "slow")
def test_train_cifar(self):

View File

@@ -283,6 +283,7 @@ class TestAssign(unittest.TestCase):
#assert ba1 == ba2 and ba1 != bb1
np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
@unittest.skip("multi output not supported anymore")
def test_simple_assignment_multioutput(self):
a = Tensor.randn(32, 32).realize()
b = Tensor.full((32, ), 1.).contiguous().realize()
@@ -321,6 +322,7 @@ class TestAssign(unittest.TestCase):
b.assign(r + b.permute(1, 0))
b.realize()
@unittest.skip("multi output not supported anymore")
def test_permuted_reduceop_multioutput_dual_use(self):
a = Tensor.randn(32, 32, 32).realize()
b = Tensor.full((32, 32), 1.).contiguous().realize()
@@ -333,6 +335,7 @@ class TestAssign(unittest.TestCase):
c.assign(r + b_perm)
Tensor.realize(b, c)
@unittest.skip("multi output not supported anymore")
def test_permuted_reduceop_multioutput_dual_use_possible(self):
a = Tensor.randn(32, 32, 32, dtype=dtypes.int).realize()
b = Tensor.arange(32 * 32).reshape(32, 32).realize()

View File

@@ -323,7 +323,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim(self):
# this is too high
- for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 11)]:
+ for optim, cnt in [(nn.optim.Adam, 30), (nn.optim.SGD, 11)]:
with self.subTest(optim=optim.__name__):
with Tensor.train():
img = Tensor.ones(1,3,4,4)
@@ -682,6 +682,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 2, filter_sink=False)
# multireduce spec
+ @unittest.expectedFailure
def test_reduce_same_size(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
@@ -694,6 +695,7 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out2.numpy(), out0_np*out1_np, atol=1e-4, rtol=1e-6)
# multireduce spec
+ @unittest.expectedFailure
def test_reduce_multiple_paths(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
@@ -714,7 +716,7 @@ class TestSchedule(unittest.TestCase):
out2 = b.sum().exp2()
out3 = b.sum() + out2
# run_schedule(check_schedule([out0, out1, out2, out3], 1))
- run_schedule(check_schedule([out0, out1, out2, out3], 2))
+ run_schedule(check_schedule([out0, out1, out2, out3], 6))
np.testing.assert_allclose(out0.numpy(), np_out0:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), np_out1:=a.numpy().sum()+np_out0, atol=1e-4, rtol=1e-4)
np_b = (a.numpy() + np_out0 + np_out1)
@@ -793,6 +795,7 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy(), atol=1e-4, rtol=1e-4)
+ @unittest.expectedFailure
def test_reduce_shrink_child(self):
a = Tensor.empty(100, 100)
b = Tensor.empty(10,)
@@ -1039,7 +1042,7 @@ class TestSchedule(unittest.TestCase):
_realize_weights(layer)
opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
layer(x).relu().sum().backward()
- check_schedule(opt.schedule_step(), 10)
+ check_schedule(opt.schedule_step(), 16)
def test_adam_conv_fuse(self):
with Tensor.train():
@@ -1049,7 +1052,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
opt.zero_grad()
c1(img).relu().sum().backward()
- check_schedule(opt.schedule_step(), 10)
+ check_schedule(opt.schedule_step(), 16)
def test_adam_2convs_fuse(self):
with Tensor.train():
@@ -1060,7 +1063,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
- check_schedule(opt.schedule_step(), 14)
+ check_schedule(opt.schedule_step(), 20)
def test_sgd_conv_fuse(self):
with Tensor.train():
@@ -1136,7 +1139,7 @@ class TestSchedule(unittest.TestCase):
shared = x.sum().half().float()
a = shared * 2
b = shared * 3
- sched = check_schedule([a, b], 1)
+ sched = check_schedule([a, b], 3)
for si in sched[:-2]: assert all(out.dtype == dtypes.half for out in si.outputs)
# reduce
@@ -1272,6 +1275,7 @@ class TestSchedule(unittest.TestCase):
# changed by: multireduce spec
# pattern in adam
+ @unittest.expectedFailure
def test_partial_fuse3(self):
Tensor.manual_seed(0)
a = Tensor.randn(16, 16).realize()
@@ -1288,6 +1292,7 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(f.numpy(), b.numpy().sum() - e_np, atol=1e-4, rtol=1e-4)
# changed by: multireduce spec
+ @unittest.expectedFailure
def test_partial_fuse4(self):
Tensor.manual_seed(0)
a = Tensor.randn(16, 16).realize()
@@ -1763,6 +1768,7 @@ class TestIndexing(unittest.TestCase):
loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())])
np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6)
+ @unittest.expectedFailure
def test_arange_fuse_grouped_children(self):
X = Tensor.randn(4, 4).realize()
r = (X+Tensor.arange(16).reshape(4, 4)).sum()
@@ -1780,7 +1786,7 @@ class TestIndexing(unittest.TestCase):
self.check_schedule([r], 1)
np.testing.assert_allclose(r.numpy(), (X.numpy()+np.arange(16).reshape(4, 4)).sum(1, keepdims=True))
- @unittest.expectedFailure
+ @unittest.skip("multi output isn't supported")
def test_multiview_arange_children(self):
X = Tensor.randn(2,3,4,4).numpy()
with Context(FUSE_ARANGE=1):
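
The count bumps in this file share one cause: schedules that previously collapsed into a single multi-output kernel now emit one kernel per stored output. A sketch of the shared half/float case above (the 16x16 input shape is assumed; the count of 3 mirrors the updated check_schedule([a, b], 3) assertion rather than a guarantee of this snippet):

# sketch of the shared-subexpression case, assuming tinygrad's public Tensor
# API; the input shape is illustrative
from tinygrad import Tensor
x = Tensor.randn(16, 16).realize()
shared = x.sum().half().float()
a = shared * 2
b = shared * 3
# one half-precision reduce kernel plus one kernel per output, rather than a
# single multi-output kernel storing both a and b
sched = Tensor.schedule(a, b)
print(len(sched))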

View File

@@ -3,7 +3,7 @@ from collections import defaultdict, deque
from dataclasses import dataclass, field
from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, type_verify, buffers
from tinygrad.ops import can_pad, identity_element, resolve, symbolic_simple, view_left, merge_views
- from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
+ from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY
from tinygrad.dtype import DType, ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
@@ -204,12 +204,9 @@ to_si = PatternMatcher([
(UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
])
- # LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel
- multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])
def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
- # remove movement ops + substitute LOAD of fused STORE with just the value
- sink = graph_rewrite(graph_rewrite(pre, multioutput+view_left, store_bufs:={x.buf_uop:x.src[2] for x in pre.src}), view_right)
+ # apply swizzles (pushing views from the middle of the AST to BUFFER ops edges)
+ sink = graph_rewrite(graph_rewrite(pre, view_left), view_right)
# remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL
ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(ctx.var_vals))
# deal with ASSIGN
@@ -222,7 +219,7 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
if x.op is Ops.PRELOAD:
assign_preloads[x.buf_uop] = None
# if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
- if x.buf_uop in store_bufs and not (st:=x.st_arg).contiguous:
+ if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous:
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
# if it has a single view and it's equal when you shrink a contig, it's fine
@@ -266,20 +263,6 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, di
if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
- def get_isolated_children(r:UOp, reduce_for_op:dict[UOp, UOp], children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp],
- realizes:dict[UOp, UOp], group:dict[UOp, None]) -> dict[UOp, None]:
- rc_parents, cache = deque(group), set()
- while rc_parents:
- if (p:=uval(allbufs[rc_parents.pop()])) in cache: continue
- cache.add(p)
- # max one reduceop per kernel
- if p.op is Ops.REDUCE_AXIS: return {}
- rc_parents.extend(x.base.buf_uop for x in p.src if is_scheduled(x.base) and x.base.buf_uop is not r)
- # search descendants of the reduceop that can cleanly group
- descendants: dict[UOp, None] = {}
- for tr in group: recursive_group(tr, unwrap(allbufs[tr].st), tr, children, allbufs, realizes, reduce_for_op, descendants, cache={})
- return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
"""search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
# find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
@@ -296,8 +279,8 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
can_chase = all(tr not in reduce_for_op for tr in group)
# TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
forced_realize = r in group
- if not forced_realize and len(group) > 1:
- group = get_isolated_children(r, reduce_for_op, ctx.children, ctx.allbufs, ctx.realizes, group)
+ # can only have one output
+ if not forced_realize and len(group) > 1: forced_realize = True
# can only fuse assign if no other assign_target is used in the kernel
if not forced_realize and any(x in ctx.assigns for x in group):
parents = deque((r, *group))
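
Restated as a toy model (plain Python, not tinygrad internals; the names are illustrative): where group_realizes previously called get_isolated_children to try to keep a multi-output group together, a reduce with more than one fusion candidate is now simply forced to realize on its own.

# toy model of the new grouping rule; r stands for a reduce, group for the
# elementwise children that wanted to fuse with it
def should_force_realize(r: str, group: dict[str, None]) -> bool:
    forced_realize = r in group  # same self-containment workaround as before
    # can only have one output: more than one candidate forces the reduce out
    if not forced_realize and len(group) > 1:
        forced_realize = True
    return forced_realize

# a reduce with two elementwise consumers no longer becomes a multi-output kernel
print(should_force_realize("r", {"child0": None, "child1": None}))  # True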