diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 6dacb5061f..91239ffb02 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -357,6 +357,22 @@ class TestHandCodedOpts(unittest.TestCase): # check that we don't do too many upcasts assert prod(k.full_shape[k.shape_len-k.upcasted:k.shape_len]) <= 49 + def test_matvec(self): + if not Device[Device.DEFAULT].linearizer_opts.has_local: + self.skipTest("Only devices with locals") + N = 128 + a = Tensor.rand(1, N).realize() + b = Tensor.rand(N, N).realize() + c = a @ b + + s = c.lazydata.schedule()[0] + k = Linearizer(s.ast) + k.hand_coded_optimizations() + + assert len(k.group_for_reduce) == 1 + assert k.local_dims == 1 + assert k.upcasted == 1 + def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False): wanna_output = None realized_ast, real_bufs = helper_realized_ast(r)