grouper tests cleanups [pr] (#9777)

* grouper tests cleanups [pr]

* viz

* tuple

* whitespace
qazal authored 2025-04-08 12:33:11 +08:00, committed by GitHub
parent 4cc7422769
commit 9963bb51e0
2 changed files with 18 additions and 50 deletions


@@ -11,7 +11,7 @@ from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType, ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, merge_views, GroupOp, view_left
+from tinygrad.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites, merge_views, view_left
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.spec import type_verify, shape_spec
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
@@ -19,7 +19,6 @@ from tinygrad.engine.grouper import view_right, sym
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
from extra.models.llama import precompute_freqs_cis
-remove_movement_ops = merge_views
def verify_ast(sink:UOp): return type_verify(list(sink.toposort), shape_spec)
class KernelCountException(Exception): pass
@@ -69,7 +68,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
@track_rewrites(named=True)
-def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, {})
+def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, merge_views+sym, {})
class TestSchedule(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
@@ -623,7 +622,7 @@ class TestSchedule(unittest.TestCase):
def test_pow_const_tensor_to_zero(self):
x = Tensor([1,2,3,4])
out = x ** Tensor(0.0)
-# NOTE: this is ConstBuffer 0 + ConstBuffer 1
+# NOTE: this is UOp.const(0) + UOp.const(1)
check_schedule(out, 0)
def test_zero_size(self):
@@ -643,7 +642,6 @@ class TestSchedule(unittest.TestCase):
out = x.sum(1).relu().elu() + y.sum(1).relu().elu()
check_schedule(out, 2)
# multireduce spec
@unittest.skipUnless(SPLIT_REDUCEOP, "Testing split reducop requires SPLIT_REDUCEOP")
def test_preserve_multistage_reduce(self):
big_enough = getenv("REDUCEOP_SPLIT_THRESHOLD", 32768)
@@ -664,7 +662,6 @@ class TestSchedule(unittest.TestCase):
out = x.relu().sum(1) + out2[0]
check_schedule(out, 2)
# multireduce spec
@unittest.skip("these two Tensors are the same")
def test_example_matmul(self):
x = Tensor.eye(64, requires_grad=True)
@@ -712,9 +709,9 @@ class TestSchedule(unittest.TestCase):
x = x.sum(1)
x = x[:16]
out = x + y
-check_schedule(out, 2) # TODO: this should be 1
+# NOTE: this could be 1 kernel if we mask the store?
+check_schedule(out, 2)
# multireduce spec
def test_multireduce_shrink(self):
Tensor.manual_seed(0)
a = Tensor.randn(32, 32).realize()
@@ -737,7 +734,6 @@ class TestSchedule(unittest.TestCase):
out = x.contiguous() + y.contiguous()
check_schedule(out, 2, filter_sink=False)
# multireduce spec
@unittest.expectedFailure
def test_reduce_same_size(self):
Tensor.manual_seed(0)
@@ -750,7 +746,6 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out1.numpy(), out1_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6)
np.testing.assert_allclose(out2.numpy(), out0_np*out1_np, atol=1e-4, rtol=1e-6)
# multireduce spec
@unittest.expectedFailure
def test_reduce_multiple_paths(self):
Tensor.manual_seed(0)
@@ -762,7 +757,6 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out0.numpy(), out0_np:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+out0_np, atol=1e-4, rtol=1e-6)
# multireduce spec
def test_multireduce_reduce_multiple_paths(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
@@ -779,7 +773,6 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out2.numpy(), np_out2:=np.exp2(np_b.sum()), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out3.numpy(), np_b.sum()+np_out2, atol=1e-4, rtol=1e-4)
# multireduce spec
def test_reduce_ext_reduce_child(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
@@ -792,7 +785,6 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+b.numpy().sum()+2, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy().sum()+4, atol=1e-4, rtol=1e-4)
# multireduce spec
def test_reduce_multiple_paths_midreduce(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
@@ -808,7 +800,6 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out1.numpy(), out1_np:=(a.numpy() - out0_np).max(), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out2.numpy(), r_np + out1_np, atol=1e-4, rtol=1e-4)
# multireduce spec
def test_reduce_multiple_paths_midreduce_fused(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
@@ -822,7 +813,6 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out1.numpy(), out1_np:=b.numpy().max() + out0_np*2, atol=1e-4, rtol=1e-6)
np.testing.assert_allclose(out2.numpy(), a.numpy().sum() + out1_np, atol=1e-4, rtol=1e-6)
# multireduce spec
def test_reduce_multiple_paths_midexpand(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
@@ -873,7 +863,6 @@ class TestSchedule(unittest.TestCase):
out1 = out0[0] + Tensor.empty(1, )
check_schedule([r, out0, out1], 3)
# multireduce spec
def test_std_multireduce_fusion(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
@@ -881,7 +870,6 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_argmin_multireduce_fusion(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
@@ -889,7 +877,6 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 3))
np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1))
# multireduce spec
def test_argmax_multireduce_fusion(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
@@ -909,7 +896,6 @@ class TestSchedule(unittest.TestCase):
compare = torch.nn.functional.scaled_dot_product_attention(torch.tensor(q.numpy()),torch.tensor(k.numpy()),torch.tensor(v.numpy()))
np.testing.assert_allclose(out.numpy(), compare.numpy(), atol=1e-6, rtol=1e-3)
# multireduce spec
def test_ugly_reduceop_pairing(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 32).realize()
@@ -921,7 +907,6 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out.numpy(), \
(c.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1) + (b.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_reduce_expand_reduce_fusion(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 32).realize()
@@ -930,7 +915,6 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), (a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_reduce_expand_reduce_expand_fusion(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 32).realize()
@@ -940,7 +924,6 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out.numpy(), \
a.numpy()+(a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1,keepdims=True), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_branching_reduces_and_expands_fusion(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 32).realize()
@@ -951,7 +934,6 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out0.numpy(), a.numpy()+a.numpy().sum(axis=-1,keepdims=True), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), (a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_multireduce_fusion_simple_sequential(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
@@ -961,7 +943,6 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), (y.numpy() + x.numpy().sum(axis=-1, keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_multireduce_fusion_simple_parallel(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
@@ -971,7 +952,6 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), y.numpy().sum(axis=-1) + x.numpy().sum(axis=-1), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_multireduce_fusion_sequential(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
@@ -980,7 +960,6 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_multireduce_fusion_parallel(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
@@ -990,7 +969,6 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 4))
np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1) + y.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_multireduce_diffops_sequential(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
@@ -999,7 +977,6 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), (x.numpy() - x.numpy().max(axis=-1, keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_multireduce_fusion_diffops_parallel(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
@@ -1009,7 +986,6 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), x.numpy().sum(axis=-1) + y.numpy().max(axis=-1), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_multireduce_fusion_sequential_and_parallel(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
@@ -1023,7 +999,6 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out[0].numpy(), np.sqrt(np.square(x.numpy() - np_mu).sum(-1)/x.shape[-1]), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out[1].numpy(), np.sqrt(np.square(y.numpy() - np_mu).sum(-1)/y.shape[-1]), atol=1e-4, rtol=1e-4)
# multireduce spec
def test_multimatmul_fusion(self):
Tensor.manual_seed(0)
a,b = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize()
@@ -1228,7 +1203,6 @@ class TestSchedule(unittest.TestCase):
schedule = check_schedule([b, c], 3)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
# multireduce spec
def test_multireduce_simple_chase(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4, 4).realize()
@@ -1252,7 +1226,6 @@ class TestSchedule(unittest.TestCase):
schedule = check_schedule([d, e], 3)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
# multireduce spec
def test_multireduce_push_permute_chase(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4, 4).realize()
@@ -1275,7 +1248,6 @@ class TestSchedule(unittest.TestCase):
schedule = check_schedule(d, 2)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD)
# multireduce spec
def test_multireduce_push_shrink_chase(self):
Tensor.manual_seed(0)
a = Tensor.randn(16, 16).realize()
@@ -1296,7 +1268,6 @@ class TestSchedule(unittest.TestCase):
schedule = check_schedule(b, 2)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)
# multireduce spec
def test_multireduce_midreduce_nochase(self):
Tensor.manual_seed(0)
a = Tensor.randn(16, 16).realize()
@@ -1376,7 +1347,6 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6)
# multireduce spec
def test_multireduce_pad_reduce_safe(self):
Tensor.manual_seed(0)
a = Tensor.randn(3, 4, 5).realize()
@@ -1394,7 +1364,6 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6)
# multireduce spec
def test_multireduce_pad_reduce_unsafe(self):
Tensor.manual_seed(0)
a = Tensor.randn(3, 4, 5).abs().realize()
@@ -1471,11 +1440,10 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(tiny_ret, p)
def test_bitcast_fuses(self):
-x = cast(UOp, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
-a = x.alu(Ops.EXP2).bitcast(dtypes.int32)
+x = Tensor.empty(1, dtype=dtypes.float32)
+a = x.exp2().bitcast(dtypes.int32)
b = x.bitcast(dtypes.int32)
-b = a.alu(Ops.ADD, b)
-check_schedule(b, 1) # this should fuse when it makes sense
+check_schedule(a+b, 1) # this should fuse when it makes sense
@unittest.skip("disabling subbuffer manually isn't supported anymore")
def test_bitcast_disable_subbufer(self):
@@ -1503,7 +1471,6 @@ class TestSchedule(unittest.TestCase):
@unittest.skip("splitting kernels exceeding device buffer count is not yet supported")
def _test_buf_cnt(self, cnt:int, allowed:int):
#if (m:=BUF_LIMIT.get(Device.DEFAULT)) is None or m != 32: self.skipTest(f"test needs a buf_max of 32 {Device.DEFAULT}")
alu = functools.reduce(lambda x,y: x+y, [Tensor.ones((1, 1)).contiguous().realize() for _ in range(cnt-1)])
s = alu.schedule()
assert len(s) == allowed
@@ -1517,6 +1484,7 @@ class TestSchedule(unittest.TestCase):
@unittest.expectedFailure
def test_buf_cnt_over_limit_alt(self): self._test_buf_cnt(63, allowed=3)
+@unittest.skipIf(getenv("VIZ"), "TODO: VIZ blocks gc")
def test_schedule_mem_used(self):
base = GlobalCounters.mem_used
Tensor.ones(256).contiguous().realize()
@@ -1648,7 +1616,7 @@ class TestIndexing(unittest.TestCase):
self.check_schedule(xt, 6)
np.testing.assert_equal(xt.numpy(), 6)
-@unittest.skip("TODO: support pads in graph_rewrite")
+@unittest.skip("TODO: break the schedule if dims don't match")
def test_advanced_simple_indexing_combined(self):
X = Tensor.arange(16).reshape(4, 4)
xt = X[1:2, [1, 2]]
@@ -2060,7 +2028,7 @@ class TestView(unittest.TestCase):
run_schedule(s)
self.assertEqual(other_child.tolist(), [2, 3, 4])
-def tensor_rewrite(t) -> UOp: return graph_rewrite(t.lazydata.base, remove_movement_ops+symbolic_simple)
+def tensor_rewrite(t) -> UOp: return graph_rewrite(t.lazydata.base, merge_views+symbolic_simple)
class TestSimplifier(unittest.TestCase):
def test_sink_childless_const(self):
x = Tensor(0)
@@ -2302,13 +2270,13 @@ class TestTensorUOpSpec(unittest.TestCase):
unsafe_push_views = PatternMatcher([
(UPat.cvar("root").view(name="view"), lambda root,view: root.replace(src=tuple(x.view(view.st) for x in root.src))),
])
-a.lazydata = graph_rewrite(a.lazydata.sink(), remove_movement_ops+merge_views+unsafe_push_views)
+a.lazydata = graph_rewrite(a.lazydata.sink(), merge_views+merge_views+unsafe_push_views)
with self.assertRaisesRegex(RuntimeError, "UOp verification failed"):
a.schedule()
def test_expanded_const_ok(self):
a = Tensor.ones((4, 4))
-t = graph_rewrite(a.lazydata.sink(), remove_movement_ops+merge_views)
+t = graph_rewrite(a.lazydata.sink(), merge_views+merge_views)
create_schedule_with_vars(t)
# NOTE: changing symbolic CONST VIEWs is not allowed
@@ -2316,7 +2284,7 @@ class TestTensorUOpSpec(unittest.TestCase):
def test_symbolic_shape_ok(self):
a = Tensor.ones(4)
vi = UOp.variable("i", 1, 10).bind(4)
-a.lazydata = graph_rewrite(a.reshape(vi).sum().lazydata, remove_movement_ops+merge_views)
+a.lazydata = graph_rewrite(a.reshape(vi).sum().lazydata, merge_views+merge_views)
a.schedule()
class TestBufferUOp(unittest.TestCase):
@@ -2348,7 +2316,7 @@ class TestBufferUOp(unittest.TestCase):
def test_buffer_view_not_allowed(self):
permuted_view = Tensor.empty(1, 2, 3).permute(0, 2, 1)
-merged = graph_rewrite(permuted_view.lazydata, remove_movement_ops)
+merged = graph_rewrite(permuted_view.lazydata, merge_views)
with self.assertRaisesRegex(AssertionError, "VIEW only works here if it's contiguous"):
merged.buffer # cannot access Buffer of a non contiguous VIEW