Test coverage for matvec (#2762)

* add test coverage for matvec

* skip devices that don't support locals
This commit is contained in:
qazal
2023-12-14 18:34:56 +02:00
committed by GitHub
parent 64fea9ff4a
commit 746cb5de21

View File

@@ -357,6 +357,22 @@ class TestHandCodedOpts(unittest.TestCase):
# check that we don't do too many upcasts
assert prod(k.full_shape[k.shape_len-k.upcasted:k.shape_len]) <= 49
def test_matvec(self):
if not Device[Device.DEFAULT].linearizer_opts.has_local:
self.skipTest("Only devices with locals")
N = 128
a = Tensor.rand(1, N).realize()
b = Tensor.rand(N, N).realize()
c = a @ b
s = c.lazydata.schedule()[0]
k = Linearizer(s.ast)
k.hand_coded_optimizations()
assert len(k.group_for_reduce) == 1
assert k.local_dims == 1
assert k.upcasted == 1
def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False):
wanna_output = None
realized_ast, real_bufs = helper_realized_ast(r)