mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
metal: fix graph when unrelated input buffers are not metal buffers (#10170)
* metal: fix graph when unrelated input buffers are not metal buffers * tinier test
This commit is contained in:
@@ -456,6 +456,22 @@ class TestJit(unittest.TestCase):
|
|||||||
np.testing.assert_allclose(a.numpy(), xc.numpy(), atol=1e-4, rtol=1e-5)
|
np.testing.assert_allclose(a.numpy(), xc.numpy(), atol=1e-4, rtol=1e-5)
|
||||||
np.testing.assert_allclose(b.numpy(), yc.numpy(), atol=1e-4, rtol=1e-5)
|
np.testing.assert_allclose(b.numpy(), yc.numpy(), atol=1e-4, rtol=1e-5)
|
||||||
|
|
||||||
|
def test_jit_several_devs(self):
|
||||||
|
d0, d1 = f"{Device.DEFAULT}:0", "CPU"
|
||||||
|
|
||||||
|
def f(a, b):
|
||||||
|
x = a.to(d0).realize()
|
||||||
|
y = b.to(d0).realize()
|
||||||
|
return x+y.realize(), x*y.realize()
|
||||||
|
|
||||||
|
jf = TinyJit(f)
|
||||||
|
for _ in range(5):
|
||||||
|
a = Tensor.randn(10, 10, device=d1).realize()
|
||||||
|
b = Tensor.randn(10, 10, device=d1).realize()
|
||||||
|
zc, wc = jf(a, b)
|
||||||
|
np.testing.assert_allclose((a.numpy()+b.numpy()), zc.numpy(), atol=1e-4, rtol=1e-5)
|
||||||
|
np.testing.assert_allclose((a.numpy()*b.numpy()), wc.numpy(), atol=1e-4, rtol=1e-5)
|
||||||
|
|
||||||
@unittest.skipIf(not_support_multi_device(), "no multi")
|
@unittest.skipIf(not_support_multi_device(), "no multi")
|
||||||
def test_jitted_view(self):
|
def test_jitted_view(self):
|
||||||
d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
|
d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
|
||||||
|
|||||||
@@ -59,10 +59,9 @@ class MetalGraph(GraphRunner):
|
|||||||
self.range = to_struct(0, len(jit_cache))
|
self.range = to_struct(0, len(jit_cache))
|
||||||
|
|
||||||
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
|
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
|
||||||
|
|
||||||
if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
|
if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
|
||||||
all_resources = dedup(self.all_resources + [x._buf.buf for x in input_rawbuffers])
|
|
||||||
|
|
||||||
|
all_resources = dedup(self.all_resources + [input_rawbuffers[input_idx]._buf.buf for input_idx in self.input_replace.values()])
|
||||||
for (j,i),input_idx in self.input_replace.items():
|
for (j,i),input_idx in self.input_replace.items():
|
||||||
computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
|
computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
|
||||||
msg("setKernelBuffer:offset:atIndex:")(computeCommand, input_rawbuffers[input_idx]._buf.buf, input_rawbuffers[input_idx]._buf.offset, i)
|
msg("setKernelBuffer:offset:atIndex:")(computeCommand, input_rawbuffers[input_idx]._buf.buf, input_rawbuffers[input_idx]._buf.offset, i)
|
||||||
|
|||||||
Reference in New Issue
Block a user