mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
metal: fix graph when unrelated input buffers are not metal buffers (#10170)
* metal: fix graph when unrelated input buffers are not metal buffers * tinier test
This commit is contained in:
@@ -456,6 +456,22 @@ class TestJit(unittest.TestCase):
|
||||
np.testing.assert_allclose(a.numpy(), xc.numpy(), atol=1e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose(b.numpy(), yc.numpy(), atol=1e-4, rtol=1e-5)
|
||||
|
||||
def test_jit_several_devs(self):
|
||||
d0, d1 = f"{Device.DEFAULT}:0", "CPU"
|
||||
|
||||
def f(a, b):
|
||||
x = a.to(d0).realize()
|
||||
y = b.to(d0).realize()
|
||||
return x+y.realize(), x*y.realize()
|
||||
|
||||
jf = TinyJit(f)
|
||||
for _ in range(5):
|
||||
a = Tensor.randn(10, 10, device=d1).realize()
|
||||
b = Tensor.randn(10, 10, device=d1).realize()
|
||||
zc, wc = jf(a, b)
|
||||
np.testing.assert_allclose((a.numpy()+b.numpy()), zc.numpy(), atol=1e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose((a.numpy()*b.numpy()), wc.numpy(), atol=1e-4, rtol=1e-5)
|
||||
|
||||
@unittest.skipIf(not_support_multi_device(), "no multi")
|
||||
def test_jitted_view(self):
|
||||
d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
|
||||
|
||||
Reference in New Issue
Block a user