diff --git a/test/unit/test_metal_graph.py b/test/unit/test_metal_graph.py new file mode 100644 index 0000000000..7fa8dab4cd --- /dev/null +++ b/test/unit/test_metal_graph.py @@ -0,0 +1,26 @@ +import unittest +from unittest.mock import MagicMock +from tinygrad import Device +from tinygrad.engine.realize import CompiledRunner + +def _ei(*offsets): + ei = MagicMock() + ei.prg = MagicMock(spec=CompiledRunner) + ei.bufs = [None if o is None else MagicMock(**{"_buf.offset": o}) for o in offsets] + return ei + +@unittest.skipUnless(Device.DEFAULT == "METAL", "Metal device required to run") +class TestMetalGraph(unittest.TestCase): + def setUp(self): + from tinygrad.runtime.graph.metal import MetalGraph + self.MetalGraph = MetalGraph + self.dev = Device[Device.DEFAULT] + + def test_supports_exec_item_normal_offset(self): + assert self.MetalGraph.supports_exec_item([self.dev], _ei(0, 100, 0xFFFFFFFF)) is True + + def test_supports_exec_item_overflow_offset(self): + assert self.MetalGraph.supports_exec_item([], _ei(0, 0x100000000)) is False + +if __name__ == "__main__": + unittest.main() diff --git a/tinygrad/runtime/graph/metal.py b/tinygrad/runtime/graph/metal.py index 9105942228..a71e21955e 100644 --- a/tinygrad/runtime/graph/metal.py +++ b/tinygrad/runtime/graph/metal.py @@ -13,7 +13,6 @@ class MetalGraph(GraphRunner): def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int], orig_valid_positions: dict[int, set[int]]|None = None): super().__init__(jit_cache, input_buffers, var_vals, orig_valid_positions) - if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException # create metal batch exec icb_descriptor = metal.MTLIndirectCommandBufferDescriptor.new() @@ -109,3 +108,9 @@ class MetalGraph(GraphRunner): if PROFILE and self.command_buffer is not None: wait_check(self.command_buffer) self.collect_timestamps() + + @staticmethod + def supports_exec_item(devs, ei:ExecItem) -> bool: + # Metal ICB replay encodes offsets as uint32; reject if any buffer offset exceeds 32-bit range. + if any(b is not None and b._buf.offset > 0xFFFFFFFF for b in ei.bufs): return False + return GraphRunner.supports_exec_item(devs, ei)